mirror of https://github.com/llvm/circt.git
399 lines
14 KiB
C++
399 lines
14 KiB
C++
//===- DCOps.cpp ----------------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "circt/Dialect/DC/DCOps.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/FunctionImplementation.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
|
|
|
using namespace circt;
|
|
using namespace dc;
|
|
using namespace mlir;
|
|
|
|
bool circt::dc::isI1ValueType(Type t) {
|
|
auto vt = t.dyn_cast<ValueType>();
|
|
if (!vt)
|
|
return false;
|
|
auto innerWidth = vt.getInnerType().getIntOrFloatBitWidth();
|
|
return innerWidth == 1;
|
|
}
|
|
|
|
namespace circt {
|
|
namespace dc {
|
|
|
|
// =============================================================================
|
|
// JoinOp
|
|
// =============================================================================
|
|
|
|
OpFoldResult JoinOp::fold(FoldAdaptor adaptor) {
|
|
// Fold simple joins (joins with 1 input).
|
|
if (auto tokens = getTokens(); tokens.size() == 1)
|
|
return tokens.front();
|
|
|
|
// These folders are disabled to work around MLIR bugs when changing
|
|
// the number of operands. https://github.com/llvm/llvm-project/issues/64280
|
|
return {};
|
|
|
|
// Remove operands which originate from a dc.source op (redundant).
|
|
auto *op = getOperation();
|
|
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
|
|
if (auto source = operand.get().getDefiningOp<dc::SourceOp>()) {
|
|
op->eraseOperand(operand.getOperandNumber());
|
|
return getOutput();
|
|
}
|
|
}
|
|
|
|
// Remove duplicate operands.
|
|
llvm::DenseSet<Value> uniqueOperands;
|
|
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
|
|
if (!uniqueOperands.insert(operand.get()).second) {
|
|
op->eraseOperand(operand.getOperandNumber());
|
|
return getOutput();
|
|
}
|
|
}
|
|
|
|
// Canonicalization staggered joins where the sink join contains inputs also
|
|
// found in the source join.
|
|
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
|
|
auto otherJoin = operand.get().getDefiningOp<dc::JoinOp>();
|
|
if (!otherJoin) {
|
|
// Operand does not originate from a join so it's a valid join input.
|
|
continue;
|
|
}
|
|
|
|
// Operand originates from a join. Erase the current join operand and add
|
|
// all of the otherJoin op's inputs to this join.
|
|
// DCE will take care of otherJoin in case it's no longer used.
|
|
op->eraseOperand(operand.getOperandNumber());
|
|
op->insertOperands(getNumOperands(), otherJoin.getTokens());
|
|
return getOutput();
|
|
}
|
|
|
|
return {};
|
|
}
|
|
|
|
// =============================================================================
|
|
// ForkOp
|
|
// =============================================================================
|
|
|
|
template <typename TInt>
|
|
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, TInt &v) {
|
|
if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
|
|
return failure();
|
|
return success();
|
|
}
|
|
|
|
ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
OpAsmParser::UnresolvedOperand operand;
|
|
size_t size = 0;
|
|
if (parseIntInSquareBrackets(parser, size))
|
|
return failure();
|
|
|
|
if (size == 0)
|
|
return parser.emitError(parser.getNameLoc(),
|
|
"fork size must be greater than 0");
|
|
|
|
if (parser.parseOperand(operand) ||
|
|
parser.parseOptionalAttrDict(result.attributes))
|
|
return failure();
|
|
|
|
auto tt = dc::TokenType::get(parser.getContext());
|
|
llvm::SmallVector<Type> operandTypes{tt};
|
|
SmallVector<Type> resultTypes{size, tt};
|
|
result.addTypes(resultTypes);
|
|
if (parser.resolveOperand(operand, tt, result.operands))
|
|
return failure();
|
|
return success();
|
|
}
|
|
|
|
void ForkOp::print(OpAsmPrinter &p) {
|
|
p << " [" << getNumResults() << "] ";
|
|
p << getOperand() << " ";
|
|
auto attrs = (*this)->getAttrs();
|
|
if (!attrs.empty()) {
|
|
p << " ";
|
|
p.printOptionalAttrDict(attrs);
|
|
}
|
|
}
|
|
|
|
class EliminateForkToForkPattern : public OpRewritePattern<ForkOp> {
|
|
// Canonicalization of forks where the output is fed into another fork.
|
|
public:
|
|
using OpRewritePattern<ForkOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(ForkOp fork,
|
|
PatternRewriter &rewriter) const override {
|
|
for (auto output : fork.getOutputs()) {
|
|
for (auto *user : output.getUsers()) {
|
|
auto userFork = dyn_cast<ForkOp>(user);
|
|
if (!userFork)
|
|
continue;
|
|
|
|
// We have a fork feeding into another fork. Replace the output fork by
|
|
// adding more outputs to the current fork.
|
|
size_t totalForks = fork.getNumResults() + userFork.getNumResults() - 1;
|
|
|
|
auto newFork = rewriter.create<dc::ForkOp>(fork.getLoc(),
|
|
fork.getToken(), totalForks);
|
|
rewriter.replaceOp(
|
|
fork, newFork.getResults().take_front(fork.getNumResults()));
|
|
rewriter.replaceOp(
|
|
userFork, newFork.getResults().take_back(userFork.getNumResults()));
|
|
|
|
// Just stop the pattern here instead of trying to do more - let the
|
|
// canonicalizer recurse if another run of the canonicalization applies.
|
|
return success();
|
|
}
|
|
}
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
class EliminateForkOfSourcePattern : public OpRewritePattern<ForkOp> {
|
|
// Canonicalizes away forks on source ops, in favor of individual source
|
|
// operations. Having standalone sources are a better alternative, since other
|
|
// operations can canonicalize on it (e.g. joins) as well as being very cheap
|
|
// to implement in hardware, if they do remain.
|
|
public:
|
|
using OpRewritePattern<ForkOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(ForkOp fork,
|
|
PatternRewriter &rewriter) const override {
|
|
auto source = fork.getToken().getDefiningOp<SourceOp>();
|
|
if (!source)
|
|
return failure();
|
|
|
|
// We have a source feeding into a fork. Replace the fork by a source for
|
|
// each output.
|
|
llvm::SmallVector<Value> sources;
|
|
for (size_t i = 0; i < fork.getNumResults(); ++i)
|
|
sources.push_back(rewriter.create<dc::SourceOp>(fork.getLoc()));
|
|
|
|
rewriter.replaceOp(fork, sources);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
void ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results.insert<EliminateForkToForkPattern, EliminateForkOfSourcePattern>(
|
|
context);
|
|
}
|
|
|
|
LogicalResult ForkOp::fold(FoldAdaptor adaptor,
|
|
SmallVectorImpl<OpFoldResult> &results) {
|
|
// Fold simple forks (forks with 1 output).
|
|
if (getOutputs().size() == 1) {
|
|
results.push_back(getToken());
|
|
return success();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
|
|
// =============================================================================
|
|
// UnpackOp
|
|
// =============================================================================
|
|
|
|
struct EliminateRedundantUnpackPattern : public OpRewritePattern<UnpackOp> {
|
|
// Eliminates unpacks where only the token is used.
|
|
using OpRewritePattern<UnpackOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(UnpackOp unpack,
|
|
PatternRewriter &rewriter) const override {
|
|
// Is the value-side of the unpack used?
|
|
if (!unpack.getOutput().use_empty())
|
|
return failure();
|
|
|
|
auto pack = unpack.getInput().getDefiningOp<PackOp>();
|
|
if (!pack)
|
|
return failure();
|
|
|
|
// Replace all uses of the unpack token with the packed token.
|
|
rewriter.replaceAllUsesWith(unpack.getToken(), pack.getToken());
|
|
rewriter.eraseOp(unpack);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
void UnpackOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results.insert<EliminateRedundantUnpackPattern>(context);
|
|
}
|
|
|
|
LogicalResult UnpackOp::fold(FoldAdaptor adaptor,
|
|
SmallVectorImpl<OpFoldResult> &results) {
|
|
// Unpack of a pack is a no-op.
|
|
if (auto pack = getInput().getDefiningOp<PackOp>()) {
|
|
results.push_back(pack.getToken());
|
|
results.push_back(pack.getInput());
|
|
return success();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
|
|
LogicalResult UnpackOp::inferReturnTypes(
|
|
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
|
|
DictionaryAttr attrs, mlir::OpaqueProperties properties,
|
|
mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
|
|
auto inputType = operands.front().getType().cast<ValueType>();
|
|
results.push_back(TokenType::get(context));
|
|
results.push_back(inputType.getInnerType());
|
|
return success();
|
|
}
|
|
|
|
// =============================================================================
|
|
// PackOp
|
|
// =============================================================================
|
|
|
|
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
|
|
auto token = getToken();
|
|
|
|
// Pack of an unpack is a no-op.
|
|
if (auto unpack = token.getDefiningOp<UnpackOp>()) {
|
|
if (unpack.getOutput() == getInput())
|
|
return unpack.getInput();
|
|
}
|
|
return {};
|
|
}
|
|
|
|
LogicalResult PackOp::inferReturnTypes(
|
|
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
|
|
DictionaryAttr attrs, mlir::OpaqueProperties properties,
|
|
mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
|
|
llvm::SmallVector<Type> inputTypes;
|
|
Type inputType = operands.back().getType();
|
|
auto valueType = dc::ValueType::get(context, inputType);
|
|
results.push_back(valueType);
|
|
return success();
|
|
}
|
|
|
|
// =============================================================================
|
|
// SelectOp
|
|
// =============================================================================
|
|
|
|
class EliminateBranchToSelectPattern : public OpRewritePattern<SelectOp> {
|
|
// Canonicalize away a select that is fed only by a single branch
|
|
// example:
|
|
// %true, %false = dc.branch %sel1 %token
|
|
// %0 = dc.select %sel2, %true, %false
|
|
// ->
|
|
// %0 = dc.join %sel1, %sel2, %token
|
|
|
|
public:
|
|
using OpRewritePattern<SelectOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(SelectOp select,
|
|
PatternRewriter &rewriter) const override {
|
|
// Do all the inputs come from a branch?
|
|
BranchOp branchInput;
|
|
for (auto operand : {select.getTrueToken(), select.getFalseToken()}) {
|
|
auto br = operand.getDefiningOp<BranchOp>();
|
|
if (!br)
|
|
return failure();
|
|
|
|
if (!branchInput)
|
|
branchInput = br;
|
|
else if (branchInput != br)
|
|
return failure();
|
|
}
|
|
|
|
// Replace the select with a join (unpack the select conditions).
|
|
rewriter.replaceOpWithNewOp<JoinOp>(
|
|
select,
|
|
llvm::SmallVector<Value>{
|
|
rewriter.create<UnpackOp>(select.getLoc(), select.getCondition())
|
|
.getToken(),
|
|
rewriter
|
|
.create<UnpackOp>(branchInput.getLoc(),
|
|
branchInput.getCondition())
|
|
.getToken()});
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
void SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results.insert<EliminateBranchToSelectPattern>(context);
|
|
}
|
|
|
|
// =============================================================================
|
|
// BufferOp
|
|
// =============================================================================
|
|
|
|
FailureOr<SmallVector<int64_t>> BufferOp::getInitValueArray() {
|
|
assert(getInitValues() && "initValues attribute not set");
|
|
SmallVector<int64_t> values;
|
|
for (auto value : getInitValuesAttr()) {
|
|
if (auto iValue = value.dyn_cast<IntegerAttr>()) {
|
|
values.push_back(iValue.getValue().getSExtValue());
|
|
} else {
|
|
return emitError() << "initValues attribute must be an array of integers";
|
|
}
|
|
}
|
|
return values;
|
|
}
|
|
|
|
LogicalResult BufferOp::verify() {
|
|
// Verify that exactly 'size' number of initial values have been provided, if
|
|
// an initializer list have been provided.
|
|
if (auto initVals = getInitValuesAttr()) {
|
|
auto nInits = initVals.size();
|
|
if (nInits != getSize())
|
|
return emitOpError() << "expected " << getSize()
|
|
<< " init values but got " << nInits << ".";
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
// =============================================================================
|
|
// ToESIOp
|
|
// =============================================================================
|
|
|
|
LogicalResult ToESIOp::inferReturnTypes(
|
|
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
|
|
DictionaryAttr attrs, mlir::OpaqueProperties properties,
|
|
mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
|
|
Type channelEltType;
|
|
if (auto valueType = operands.front().getType().dyn_cast<ValueType>())
|
|
channelEltType = valueType.getInnerType();
|
|
else {
|
|
// dc.token => esi.channel<i0>
|
|
channelEltType = IntegerType::get(context, 0);
|
|
}
|
|
|
|
results.push_back(esi::ChannelType::get(context, channelEltType));
|
|
return success();
|
|
}
|
|
|
|
// =============================================================================
|
|
// FromESIOp
|
|
// =============================================================================
|
|
|
|
LogicalResult FromESIOp::inferReturnTypes(
|
|
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
|
|
DictionaryAttr attrs, mlir::OpaqueProperties properties,
|
|
mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
|
|
auto innerType =
|
|
operands.front().getType().cast<esi::ChannelType>().getInner();
|
|
if (auto intType = innerType.dyn_cast<IntegerType>(); intType.getWidth() == 0)
|
|
results.push_back(dc::TokenType::get(context));
|
|
else
|
|
results.push_back(dc::ValueType::get(context, innerType));
|
|
|
|
return success();
|
|
}
|
|
|
|
} // namespace dc
|
|
} // namespace circt
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "circt/Dialect/DC/DC.cpp.inc"
|