[Sim] Flatten format string concatenations in canonicalizer (#7316)

Provide an interface to get the flat format string for sim.fmt.concat operations and opportunistically flatten during canonicalization.
This commit is contained in:
fzi-hielscher 2024-07-16 22:35:33 +02:00 committed by GitHub
parent a9ac3ae4b0
commit 911988f8e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 169 additions and 8 deletions

View File

@ -297,7 +297,34 @@ def FormatStringConcatOp : SimOp<"fmt.concat", [Pure]> {
let hasVerifier = true;
let assemblyFormat = "` ` `(` $inputs `)` attr-dict";
let extraClassDeclaration = [{
/// Returns true iff all of the input strings are primitive
/// (i.e. non-concatenated) tokens or block arguments.
bool isFlat() {
return llvm::none_of(getInputs(), [](Value operand) {
return !!operand.getDefiningOp<circt::sim::FormatStringConcatOp>();
});
};
/// Attempts to flatten this operation's input strings as much as possible.
///
/// The flattened values are pushed into the passed vector.
/// If the concatenation is acyclic, the function will return 'success'
/// and all the flattened values are guaranteed to _not_ be the result of
/// a format string concatenation.
/// If a cycle is encountered, the function will return 'failure'.
/// On encountering a cycle, the result of the concat operation
/// forming the cycle is pushed into the list of flattened values
/// and flattening continues without recursing into the cycle.
LogicalResult getFlattenedInputs(llvm::SmallVectorImpl<Value> &flatOperands);
}];
let builders = [
OpBuilder<(ins "mlir::ValueRange":$inputs), [{
return build($_builder, $_state, circt::sim::FormatStringType::get($_builder.getContext()), inputs);
}]>
];
}
#endif // CIRCT_DIALECT_SIM_SIMOPS_TD

View File

@ -12,6 +12,8 @@
#include "circt/Dialect/Sim/SimOps.h"
#include "circt/Dialect/HW/ModuleImplementation.h"
#include "circt/Dialect/SV/SVOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionImplementation.h"
@ -190,8 +192,12 @@ static StringAttr concatLiterals(MLIRContext *ctxt, ArrayRef<StringRef> lits) {
OpFoldResult FormatStringConcatOp::fold(FoldAdaptor adaptor) {
if (getNumOperands() == 0)
return StringAttr::get(getContext(), "");
if (getNumOperands() == 1)
if (getNumOperands() == 1) {
// Don't fold to our own result to avoid an infinte loop.
if (getResult() == getOperand(0))
return {};
return getOperand(0);
}
// Fold if all operands are literals.
SmallVector<StringRef> lits;
@ -204,6 +210,49 @@ OpFoldResult FormatStringConcatOp::fold(FoldAdaptor adaptor) {
return concatLiterals(getContext(), lits);
}
LogicalResult FormatStringConcatOp::getFlattenedInputs(
llvm::SmallVectorImpl<Value> &flatOperands) {
llvm::SmallMapVector<FormatStringConcatOp, unsigned, 4> concatStack;
bool isCyclic = false;
// Perform a DFS on this operation's concatenated operands,
// collect the leaf format string tokens.
concatStack.insert({*this, 0});
while (!concatStack.empty()) {
auto &top = concatStack.back();
auto currentConcat = top.first;
unsigned operandIndex = top.second;
// Iterate over concatenated operands
while (operandIndex < currentConcat.getNumOperands()) {
auto currentOperand = currentConcat.getOperand(operandIndex);
if (auto nextConcat =
currentOperand.getDefiningOp<FormatStringConcatOp>()) {
// Concat of a concat
if (!concatStack.contains(nextConcat)) {
// Save the next operand index to visit on the
// stack and put the new concat on top.
top.second = operandIndex + 1;
concatStack.insert({nextConcat, 0});
break;
}
// Cyclic concatenation encountered. Don't recurse.
isCyclic = true;
}
flatOperands.push_back(currentOperand);
operandIndex++;
}
// Pop the concat off of the stack if we have visited all operands.
if (operandIndex >= currentConcat.getNumOperands())
concatStack.pop_back();
}
return success(!isCyclic);
}
LogicalResult FormatStringConcatOp::verify() {
if (llvm::any_of(getOperands(),
[&](Value operand) { return operand == getResult(); }))
@ -213,11 +262,30 @@ LogicalResult FormatStringConcatOp::verify() {
LogicalResult FormatStringConcatOp::canonicalize(FormatStringConcatOp op,
PatternRewriter &rewriter) {
if (op.getNumOperands() < 2)
return failure(); // Should be handled by the folder
auto fmtStrType = FormatStringType::get(op.getContext());
// Check if we can flatten concats of concats
bool hasBeenFlattened = false;
SmallVector<Value, 0> flatOperands;
if (!op.isFlat()) {
// Get a new, flattened list of operands
flatOperands.reserve(op.getNumOperands() + 4);
auto isAcyclic = op.getFlattenedInputs(flatOperands);
if (failed(isAcyclic)) {
// Infinite recursion, but we cannot fail compilation right here (can we?)
// so just emit a warning and bail out.
op.emitWarning("Cyclic concatenation detected.");
return failure();
}
hasBeenFlattened = true;
}
if (!hasBeenFlattened && op.getNumOperands() < 2)
return failure(); // Should be handled by the folder
// Check if there are adjacent literals we can merge or empty literals to
// remove
SmallVector<StringRef> litSequence;
@ -225,7 +293,8 @@ LogicalResult FormatStringConcatOp::canonicalize(FormatStringConcatOp op,
newOperands.reserve(op.getNumOperands());
FormatLitOp prevLitOp;
for (auto operand : op.getOperands()) {
auto oldOperands = hasBeenFlattened ? flatOperands : op.getOperands();
for (auto operand : oldOperands) {
if (auto litOp = operand.getDefiningOp<FormatLitOp>()) {
if (!litOp.getLiteral().empty()) {
prevLitOp = litOp;
@ -263,7 +332,7 @@ LogicalResult FormatStringConcatOp::canonicalize(FormatStringConcatOp op,
}
}
if (newOperands.size() == op.getNumOperands())
if (!hasBeenFlattened && newOperands.size() == op.getNumOperands())
return failure(); // Nothing changed
if (newOperands.empty())

View File

@ -117,3 +117,51 @@ hw.module @constant_fold3(in %zeroWitdh: i0, out res: !sim.fstring) {
%cat = sim.fmt.concat (%foo, %clf, %ccr, %null, %foo, %null, %cext)
hw.output %cat : !sim.fstring
}
// CHECK-LABEL: hw.module @flatten_concat1
// CHECK-DAG: %[[LH:.+]] = sim.fmt.lit "Hex: "
// CHECK-DAG: %[[LD:.+]] = sim.fmt.lit "Dec: "
// CHECK-DAG: %[[LB:.+]] = sim.fmt.lit "Bin: "
// CHECK-DAG: %[[FH:.+]] = sim.fmt.hex %val : i8
// CHECK-DAG: %[[FD:.+]] = sim.fmt.dec %val : i8
// CHECK-DAG: %[[FB:.+]] = sim.fmt.bin %val : i8
// CHECK-DAG: %[[CAT:.+]] = sim.fmt.concat (%[[LB]], %[[FB]], %[[LD]], %[[FD]], %[[LH]], %[[FH]])
// CHECK: hw.output %[[CAT]] : !sim.fstring
hw.module @flatten_concat1(in %val : i8, out res: !sim.fstring) {
%binLit = sim.fmt.lit "Bin: "
%binVal = sim.fmt.bin %val : i8
%binCat = sim.fmt.concat (%binLit, %binVal)
%decLit = sim.fmt.lit "Dec: "
%decVal = sim.fmt.dec %val : i8
%decCat = sim.fmt.concat (%decLit, %decVal, %nocat)
%nocat = sim.fmt.concat ()
%hexLit = sim.fmt.lit "Hex: "
%hexVal = sim.fmt.hex %val : i8
%hexCat = sim.fmt.concat (%hexLit, %hexVal)
%catcat = sim.fmt.concat (%binCat, %nocat, %decCat, %nocat, %hexCat)
hw.output %catcat : !sim.fstring
}
// CHECK-LABEL: hw.module @flatten_concat2
// CHECK-DAG: %[[F:.+]] = sim.fmt.lit "Foo"
// CHECK-DAG: %[[FF:.+]] = sim.fmt.lit "FooFoo"
// CHECK-DAG: %[[CHR:.+]] = sim.fmt.char %val : i8
// CHECK-DAG: %[[CAT:.+]] = sim.fmt.concat (%[[F]], %[[CHR]], %[[FF]], %[[CHR]], %[[FF]], %[[CHR]], %[[FF]], %[[CHR]], %[[FF]], %[[CHR]], %[[F]])
// CHECK: hw.output %[[CAT]] : !sim.fstring
hw.module @flatten_concat2(in %val : i8, out res: !sim.fstring) {
%foo = sim.fmt.lit "Foo"
%char = sim.fmt.char %val : i8
%c = sim.fmt.concat (%foo, %char, %foo)
%cc = sim.fmt.concat (%c, %c)
%ccccc = sim.fmt.concat (%cc, %c, %cc)
hw.output %ccccc : !sim.fstring
}

View File

@ -1,8 +1,25 @@
// RUN: circt-opt %s --split-input-file --verify-diagnostics
// RUN: circt-opt %s --split-input-file --verify-diagnostics --canonicalize
hw.module @fmt_infinite_concat() {
hw.module @fmt_infinite_concat_verify() {
%lp = sim.fmt.lit ", {"
%rp = sim.fmt.lit "}"
// expected-error @below {{op is infinitely recursive.}}
%ordinal = sim.fmt.concat (%ordinal, %lp, %ordinal, %rp)
}
// -----
hw.module @fmt_infinite_concat_canonicalize(in %val : i8, out res: !sim.fstring) {
%c = sim.fmt.char %val : i8
%0 = sim.fmt.lit "Here we go round the"
%1 = sim.fmt.lit "prickly pear"
// expected-warning @below {{Cyclic concatenation detected.}}
%2 = sim.fmt.concat (%1, %c, %4)
// expected-warning @below {{Cyclic concatenation detected.}}
%3 = sim.fmt.concat (%1, %c, %1, %c, %2, %c)
// expected-warning @below {{Cyclic concatenation detected.}}
%4 = sim.fmt.concat (%0, %c, %3)
%5 = sim.fmt.lit "At five o'clock in the morning."
// expected-warning @below {{Cyclic concatenation detected.}}
%cat = sim.fmt.concat (%4, %c, %5)
hw.output %cat : !sim.fstring
}