mirror of https://github.com/llvm/circt.git
[DC] Add + re-enable canonicalization patterns (#7952)
Various canonicalization patterns were implemented as fold methods. However, these had been turned off a while ago, due to an MLIR bug. This PR moves the fold methods to actual canonicalization patterns, adds a new one (join on branch), and re-enables previously disabled tests. Co-authored-by: Morten Borup Petersen <mpetersen@microsoft.com>
This commit is contained in:
parent
d2a31174cf
commit
6d894b901b
|
@ -94,6 +94,7 @@ def JoinOp : DCOp<"join", [Commutative]> {
|
|||
|
||||
let assemblyFormat = "$tokens attr-dict";
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
let skipDefaultBuilders = 1;
|
||||
let builders = [
|
||||
OpBuilder<(ins "mlir::ValueRange":$ins), [{
|
||||
|
|
|
@ -38,46 +38,125 @@ OpFoldResult JoinOp::fold(FoldAdaptor adaptor) {
|
|||
if (auto tokens = getTokens(); tokens.size() == 1)
|
||||
return tokens.front();
|
||||
|
||||
// These folders are disabled to work around MLIR bugs when changing
|
||||
// the number of operands. https://github.com/llvm/llvm-project/issues/64280
|
||||
return {};
|
||||
}
|
||||
|
||||
// Remove operands which originate from a dc.source op (redundant).
|
||||
auto *op = getOperation();
|
||||
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
|
||||
if (auto source = operand.get().getDefiningOp<dc::SourceOp>()) {
|
||||
op->eraseOperand(operand.getOperandNumber());
|
||||
return getOutput();
|
||||
}
|
||||
}
|
||||
struct JoinOnBranchPattern : public OpRewritePattern<JoinOp> {
|
||||
using OpRewritePattern<JoinOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(JoinOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
// Remove duplicate operands.
|
||||
llvm::DenseSet<Value> uniqueOperands;
|
||||
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
|
||||
if (!uniqueOperands.insert(operand.get()).second) {
|
||||
op->eraseOperand(operand.getOperandNumber());
|
||||
return getOutput();
|
||||
}
|
||||
}
|
||||
struct BranchOperandInfo {
|
||||
// Unique operands from the branch op, in case we have the same operand
|
||||
// from the branch op multiple times.
|
||||
SetVector<Value> uniqueOperands;
|
||||
// Indices which the operands are at in the join op.
|
||||
BitVector indices;
|
||||
};
|
||||
|
||||
// Canonicalization staggered joins where the sink join contains inputs also
|
||||
// found in the source join.
|
||||
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
|
||||
auto otherJoin = operand.get().getDefiningOp<dc::JoinOp>();
|
||||
if (!otherJoin) {
|
||||
// Operand does not originate from a join so it's a valid join input.
|
||||
continue;
|
||||
DenseMap<BranchOp, BranchOperandInfo> branchOperands;
|
||||
for (auto &opOperand : op->getOpOperands()) {
|
||||
auto branch = opOperand.get().getDefiningOp<BranchOp>();
|
||||
if (!branch)
|
||||
continue;
|
||||
|
||||
BranchOperandInfo &info = branchOperands[branch];
|
||||
info.uniqueOperands.insert(opOperand.get());
|
||||
info.indices.resize(op->getNumOperands());
|
||||
info.indices.set(opOperand.getOperandNumber());
|
||||
}
|
||||
|
||||
// Operand originates from a join. Erase the current join operand and add
|
||||
// all of the otherJoin op's inputs to this join.
|
||||
// DCE will take care of otherJoin in case it's no longer used.
|
||||
op->eraseOperand(operand.getOperandNumber());
|
||||
op->insertOperands(getNumOperands(), otherJoin.getTokens());
|
||||
return getOutput();
|
||||
}
|
||||
if (branchOperands.empty())
|
||||
return failure();
|
||||
|
||||
return {};
|
||||
// Do we have both operands from any given branch op?
|
||||
for (auto &it : branchOperands) {
|
||||
auto branch = it.first;
|
||||
auto &operandInfo = it.second;
|
||||
if (operandInfo.uniqueOperands.size() != 2) {
|
||||
// We don't have both operands from the branch op.
|
||||
continue;
|
||||
}
|
||||
|
||||
// We have both operands from the branch op. Replace the join op with the
|
||||
// branch op's data operand.
|
||||
|
||||
// Unpack the !dc.value<i1> input to the branch op
|
||||
auto unpacked =
|
||||
rewriter.create<UnpackOp>(op.getLoc(), branch.getCondition());
|
||||
rewriter.modifyOpInPlace(op, [&]() {
|
||||
op->eraseOperands(operandInfo.indices);
|
||||
op.getTokensMutable().append({unpacked.getToken()});
|
||||
});
|
||||
|
||||
// Only attempt a single branch at a time - else we'd have to maintain
|
||||
// OpOperand indices during the loop... too complicated, let recursive
|
||||
// pattern application handle this.
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
struct StaggeredJoinCanonicalization : public OpRewritePattern<JoinOp> {
|
||||
using OpRewritePattern<JoinOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(JoinOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
|
||||
auto otherJoin = operand.get().getDefiningOp<dc::JoinOp>();
|
||||
if (!otherJoin) {
|
||||
// Operand does not originate from a join so it's a valid join input.
|
||||
continue;
|
||||
}
|
||||
|
||||
// Operand originates from a join. Erase the current join operand and
|
||||
// add all of the otherJoin op's inputs to this join.
|
||||
// DCE will take care of otherJoin in case it's no longer used.
|
||||
rewriter.modifyOpInPlace(op, [&]() {
|
||||
op.getTokensMutable().erase(operand.getOperandNumber());
|
||||
op.getTokensMutable().append(otherJoin.getTokens());
|
||||
});
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
struct RemoveJoinOnSourcePattern : public OpRewritePattern<JoinOp> {
|
||||
using OpRewritePattern<JoinOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(JoinOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
|
||||
if (auto source = operand.get().getDefiningOp<dc::SourceOp>()) {
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [&]() { op->eraseOperand(operand.getOperandNumber()); });
|
||||
return success();
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
struct RemoveDuplicateJoinOperandsPattern : public OpRewritePattern<JoinOp> {
|
||||
using OpRewritePattern<JoinOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(JoinOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
llvm::DenseSet<Value> uniqueOperands;
|
||||
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
|
||||
if (!uniqueOperands.insert(operand.get()).second) {
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [&]() { op->eraseOperand(operand.getOperandNumber()); });
|
||||
return success();
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
void JoinOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<RemoveDuplicateJoinOperandsPattern, RemoveJoinOnSourcePattern,
|
||||
StaggeredJoinCanonicalization, JoinOnBranchPattern>(context);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
|
|
@ -13,12 +13,11 @@ func.func @staggeredJoin1(%a: !dc.token, %b : !dc.token) -> (!dc.token) {
|
|||
return %1 : !dc.token
|
||||
}
|
||||
|
||||
// TODO: For some reason, the canonicalizer no longer combines the two joins. Investigate.
|
||||
// CHECK-LABEL: func.func @staggeredJoin2(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !dc.token, %[[VAL_1:.*]]: !dc.token, %[[VAL_2:.*]]: !dc.token, %[[VAL_3:.*]]: !dc.token) -> !dc.token {
|
||||
// CHECKx: %[[VAL_4:.*]] = dc.join %[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]]
|
||||
// CHECKx: return %[[VAL_4]] : !dc.token
|
||||
// CHECKx: }
|
||||
// CHECK: %[[VAL_4:.*]] = dc.join %[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]]
|
||||
// CHECK: return %[[VAL_4]] : !dc.token
|
||||
// CHECK: }
|
||||
func.func @staggeredJoin2(%a: !dc.token, %b : !dc.token, %c : !dc.token, %d : !dc.token) -> (!dc.token) {
|
||||
%0 = dc.join %a, %b
|
||||
%1 = dc.join %c, %0, %d
|
||||
|
@ -102,13 +101,11 @@ func.func @forkToFork2(%a: !dc.token) -> (!dc.token, !dc.token, !dc.token) {
|
|||
return %0, %2, %3 : !dc.token, !dc.token, !dc.token
|
||||
}
|
||||
|
||||
// TODO: For some reason, the canonicalizer no longer simplifies this redundant
|
||||
// triangle pattern. Investigate.
|
||||
// CHECK-LABEL: func.func @merge(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !dc.value<i1>) -> !dc.token {
|
||||
// CHECKx: %[[VAL_1:.*]], %[[VAL_2:.*]] = dc.unpack %[[VAL_0]] : !dc.value<i1>
|
||||
// CHECKx: return %[[VAL_1]] : !dc.token
|
||||
// CHECKx: }
|
||||
// CHECK: %[[VAL_1:.*]], %[[VAL_2:.*]] = dc.unpack %[[VAL_0]] : !dc.value<i1>
|
||||
// CHECK: return %[[VAL_1]] : !dc.token
|
||||
// CHECK: }
|
||||
func.func @merge(%sel : !dc.value<i1>) -> (!dc.token) {
|
||||
// Canonicalize away a merge that is fed by a branch with the same select
|
||||
// input.
|
||||
|
@ -117,20 +114,35 @@ func.func @merge(%sel : !dc.value<i1>) -> (!dc.token) {
|
|||
return %0 : !dc.token
|
||||
}
|
||||
|
||||
// TODO: For some reason, the canonicalizer no longer removes the source->join
|
||||
// pattern. Investigate.
|
||||
// CHECK-LABEL: func.func @joinOnSource(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !dc.token,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !dc.token) -> !dc.token {
|
||||
// CHECKx: %[[VAL_2:.*]] = dc.join %[[VAL_0]], %[[VAL_1]]
|
||||
// CHECKx: return %[[VAL_2]] : !dc.token
|
||||
// CHECKx: }
|
||||
// CHECK: %[[VAL_2:.*]] = dc.join %[[VAL_0]], %[[VAL_1]]
|
||||
// CHECK: return %[[VAL_2]] : !dc.token
|
||||
// CHECK: }
|
||||
func.func @joinOnSource(%a : !dc.token, %b : !dc.token) -> (!dc.token) {
|
||||
%0 = dc.source
|
||||
%out = dc.join %a, %0, %b
|
||||
return %out : !dc.token
|
||||
}
|
||||
|
||||
|
||||
// Join on branch, where all branch results are used in the join is a no-op,
|
||||
// and the join can use the token of the input value to the branch.
|
||||
// CHECK-LABEL: func.func @joinOnBranch(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !dc.value<i1>, %[[VAL_1:.*]]: !dc.value<i1>, %[[VAL_2:.*]]: !dc.token) -> (!dc.token, !dc.token) {
|
||||
// CHECK: %[[VAL_3:.*]], %[[VAL_4:.*]] = dc.branch %[[VAL_1]]
|
||||
// CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]] = dc.unpack %[[VAL_0]] : !dc.value<i1>
|
||||
// CHECK: %[[VAL_7:.*]] = dc.join %[[VAL_2]], %[[VAL_3]], %[[VAL_5]]
|
||||
// CHECK: return %[[VAL_7]], %[[VAL_4]] : !dc.token, !dc.token
|
||||
// CHECK: }
|
||||
func.func @joinOnBranch(%sel : !dc.value<i1>, %sel2 : !dc.value<i1>, %other : !dc.token) -> (!dc.token, !dc.token) {
|
||||
%true, %false = dc.branch %sel
|
||||
%true2, %false2 = dc.branch %sel2
|
||||
%out = dc.join %true, %false, %other, %true2
|
||||
return %out, %false2 : !dc.token, !dc.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @forkOfSource() -> (!dc.token, !dc.token) {
|
||||
// CHECK: %[[VAL_0:.*]] = dc.source
|
||||
// CHECK: return %[[VAL_0]], %[[VAL_0]] : !dc.token, !dc.token
|
||||
|
|
Loading…
Reference in New Issue