[mlir][DRR] Add location directive

Summary:
Add directive to indicate the location to give to op being created. This
directive is optional and if unused the location will still be the fused
location of all source operations.

Currently this directive only works with other op locations, reusing an
existing op location or a fusion of op locations. But doesn't yet support
supplying metadata for the FusedLoc.

Based off initial revision by antiagainst@ and effectively mirrors GlobalIsel
debug_locations directive.

Differential Revision: https://reviews.llvm.org/D77649
This commit is contained in:
Jacques Pienaar 2020-04-07 07:44:19 -07:00
parent c8de17bca6
commit 3f7439b280
7 changed files with 179 additions and 29 deletions

View File

@ -657,9 +657,53 @@ pattern. This is based on the heuristics and assumptions that:
The fourth parameter to `Pattern` (and `Pat`) allows to manually tweak a The fourth parameter to `Pattern` (and `Pat`) allows to manually tweak a
pattern's benefit. Just supply `(addBenefit N)` to add `N` to the benefit value. pattern's benefit. Just supply `(addBenefit N)` to add `N` to the benefit value.
## Special directives ## Rewrite directives
[TODO] ### `location`
By default the C++ pattern expanded from a DRR pattern uses the fused location
of all source ops as the location for all generated ops. This is not always the
best location mapping relationship. For such cases, DRR provides the `location`
directive to provide finer control.
`location` is of the following syntax:
```tablgen
(location $symbol0, $symbol1, ...)
```
where all `$symbol` should be bound previously in the pattern.
`location` must be used as the last argument to an op creation. For example,
```tablegen
def : Pat<(LocSrc1Op:$src1 (LocSrc2Op:$src2 ...),
(LocDst1Op (LocDst2Op ..., (location $src2)))>;
```
In the above pattern, the generated `LocDst2Op` will use the matched location
of `LocSrc2Op` while the root `LocDst1Op` node will still se the fused location
of all source Ops.
### `replaceWithValue`
The `replaceWithValue` directive is used to eliminate a matched op by replacing
all of it uses with a captured value. It is of the following syntax:
```tablegen
(replaceWithValue $symbol)
```
where `$symbol` should be a symbol bound previously in the pattern.
For example,
```tablegen
def : Pat<(Foo $input), (replaceWithValue $input)>;
```
The above pattern removes the `Foo` and replaces all uses of `Foo` with
`$input`.
## Debugging Tips ## Debugging Tips

View File

@ -2179,9 +2179,14 @@ class NativeCodeCall<string expr> {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Common directives // Rewrite directives
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Directive used in result pattern to specify the location of the generated
// op. This directive must be used as the last argument to the op creation
// DAG construct. The arguments to location must be previously captured symbol.
def location;
// Directive used in result pattern to indicate that no new op are generated, // Directive used in result pattern to indicate that no new op are generated,
// so to replace the matched DAG with an existing SSA value. // so to replace the matched DAG with an existing SSA value.
def replaceWithValue; def replaceWithValue;

View File

@ -159,6 +159,9 @@ public:
// value. // value.
bool isReplaceWithValue() const; bool isReplaceWithValue() const;
// Returns whether this DAG represents the location of an op creation.
bool isLocationDirective() const;
// Returns true if this DAG node is wrapping native code call. // Returns true if this DAG node is wrapping native code call.
bool isNativeCodeCall() const; bool isNativeCodeCall() const;

View File

@ -103,7 +103,7 @@ bool tblgen::DagNode::isNativeCodeCall() const {
} }
bool tblgen::DagNode::isOperation() const { bool tblgen::DagNode::isOperation() const {
return !(isNativeCodeCall() || isReplaceWithValue()); return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
} }
llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
@ -159,6 +159,11 @@ bool tblgen::DagNode::isReplaceWithValue() const {
return dagOpDef->getName() == "replaceWithValue"; return dagOpDef->getName() == "replaceWithValue";
} }
bool tblgen::DagNode::isLocationDirective() const {
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return dagOpDef->getName() == "location";
}
void tblgen::DagNode::print(raw_ostream &os) const { void tblgen::DagNode::print(raw_ostream &os) const {
if (node) if (node)
node->print(os); node->print(os);
@ -533,7 +538,14 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
auto numOpArgs = op.getNumArgs(); auto numOpArgs = op.getNumArgs();
auto numTreeArgs = tree.getNumArgs(); auto numTreeArgs = tree.getNumArgs();
if (numOpArgs != numTreeArgs) { // The pattern might have the last argument specifying the location.
bool hasLocDirective = false;
if (numTreeArgs != 0) {
if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
hasLocDirective = lastArg.isLocationDirective();
}
if (numOpArgs != numTreeArgs - hasLocDirective) {
auto err = formatv("op '{0}' argument number mismatch: " auto err = formatv("op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition", "{1} in pattern vs. {2} in definition",
op.getOperationName(), numTreeArgs, numOpArgs); op.getOperationName(), numTreeArgs, numOpArgs);

View File

@ -501,6 +501,20 @@ def StringAttrPrettyNameOp
let parser = [{ return ::parse$cppClass(parser, result); }]; let parser = [{ return ::parse$cppClass(parser, result); }];
} }
//===----------------------------------------------------------------------===//
// Test Locations
//===----------------------------------------------------------------------===//
def TestLocationSrcOp : TEST_Op<"loc_src"> {
let arguments = (ins I32:$input);
let results = (outs I32:$output);
}
def TestLocationDstOp : TEST_Op<"loc_dst", [SameOperandsAndResultType]> {
let arguments = (ins I32:$input);
let results = (outs I32:$output);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Test Patterns // Test Patterns
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -995,6 +1009,18 @@ def : Pat<(OneI32ResultOp),
(replaceWithValue $results__2), (replaceWithValue $results__2),
ConstantAttr<I32Attr, "2">)>; ConstantAttr<I32Attr, "2">)>;
//===----------------------------------------------------------------------===//
// Test Patterns (Location)
// Test that we can specify locations for generated ops.
def : Pat<(TestLocationSrcOp:$res1
(TestLocationSrcOp:$res2
(TestLocationSrcOp:$res3 $input))),
(TestLocationDstOp
(TestLocationDstOp
(TestLocationDstOp $input, (location $res1))),
(location $res2, $res3))>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Test Legalization // Test Legalization
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt -test-patterns -mlir-print-debuginfo %s | FileCheck %s // RUN: mlir-opt -test-patterns -mlir-print-debuginfo %s | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: verifyFusedLocs // CHECK-LABEL: verifyFusedLocs
func @verifyFusedLocs(%arg0 : i32) -> i32 { func @verifyFusedLocs(%arg0 : i32) -> i32 {
@ -10,6 +10,21 @@ func @verifyFusedLocs(%arg0 : i32) -> i32 {
return %result : i32 return %result : i32
} }
// CHECK-LABEL: verifyDesignatedLoc
func @verifyDesignatedLoc(%arg0 : i32) -> i32 {
%0 = "test.loc_src"(%arg0) : (i32) -> i32 loc("loc3")
%1 = "test.loc_src"(%0) : (i32) -> i32 loc("loc2")
%2 = "test.loc_src"(%1) : (i32) -> i32 loc("loc1")
// CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc("loc1")
// CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused[
// CHECK-SAME: "loc1"
// CHECK-SAME: "loc3"
// CHECK-SAME: "loc2"
// CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused["loc2", "loc3"])
return %1 : i32
}
// CHECK-LABEL: verifyZeroResult // CHECK-LABEL: verifyZeroResult
func @verifyZeroResult(%arg0 : i32) { func @verifyZeroResult(%arg0 : i32) {
// CHECK: "test.op_i"(%arg0) : (i32) -> () // CHECK: "test.op_i"(%arg0) : (i32) -> ()

View File

@ -109,9 +109,11 @@ private:
// calling native C++ code. // calling native C++ code.
std::string handleReplaceWithNativeCodeCall(DagNode resultTree); std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
// Returns the C++ expression referencing the old value serving as the // Returns the symbol of the old value serving as the replacement.
// replacement. StringRef handleReplaceWithValue(DagNode tree);
std::string handleReplaceWithValue(DagNode tree);
// Returns the symbol of the value whose location to use.
std::string handleUseLocationOf(DagNode tree);
// Emits the C++ statement to build a new op out of the given DAG `tree` and // Emits the C++ statement to build a new op out of the given DAG `tree` and
// returns the variable name that this op is assigned to. If the root op in // returns the variable name that this op is assigned to. If the root op in
@ -580,11 +582,11 @@ void PatternEmitter::emitRewriteLogic() {
PrintFatalError(loc, error); PrintFatalError(loc, error);
} }
os.indent(4) << "auto loc = rewriter.getFusedLoc({"; os.indent(4) << "auto odsLoc = rewriter.getFusedLoc({";
for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
} }
os << "}); (void)loc;\n"; os << "}); (void)odsLoc;\n";
// Process auxiliary result patterns. // Process auxiliary result patterns.
for (int i = 0; i < replStartIndex; ++i) { for (int i = 0; i < replStartIndex; ++i) {
@ -640,15 +642,19 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
LLVM_DEBUG(resultTree.print(llvm::dbgs())); LLVM_DEBUG(resultTree.print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n'); LLVM_DEBUG(llvm::dbgs() << '\n');
if (resultTree.isLocationDirective()) {
PrintFatalError(loc,
"location directive can only be used with op creation");
}
if (resultTree.isNativeCodeCall()) { if (resultTree.isNativeCodeCall()) {
auto symbol = handleReplaceWithNativeCodeCall(resultTree); auto symbol = handleReplaceWithNativeCodeCall(resultTree);
symbolInfoMap.bindValue(symbol); symbolInfoMap.bindValue(symbol);
return symbol; return symbol;
} }
if (resultTree.isReplaceWithValue()) { if (resultTree.isReplaceWithValue())
return handleReplaceWithValue(resultTree); return handleReplaceWithValue(resultTree).str();
}
// Normal op creation. // Normal op creation.
auto symbol = handleOpCreation(resultTree, resultIndex, depth); auto symbol = handleOpCreation(resultTree, resultIndex, depth);
@ -660,7 +666,7 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
return symbol; return symbol;
} }
std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
assert(tree.isReplaceWithValue()); assert(tree.isReplaceWithValue());
if (tree.getNumArgs() != 1) { if (tree.getNumArgs() != 1) {
@ -672,7 +678,30 @@ std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
} }
return std::string(tree.getArgName(0)); return tree.getArgName(0);
}
std::string PatternEmitter::handleUseLocationOf(DagNode tree) {
assert(tree.isLocationDirective());
auto lookUpArgLoc = [this, &tree](int idx) {
const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
};
if (tree.getNumArgs() != 1) {
std::string ret;
llvm::raw_string_ostream os(ret);
os << "rewriter.getFusedLoc({";
for (int i = 0, e = tree.getNumArgs(); i != e; ++i)
os << (i ? ", " : "") << lookUpArgLoc(i);
os << "})";
return os.str();
}
if (!tree.getSymbol().empty())
PrintFatalError(loc, "cannot bind symbol to location");
return lookUpArgLoc(0);
} }
std::string PatternEmitter::handleOpArgument(DagLeaf leaf, std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
@ -753,14 +782,28 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
Operator &resultOp = tree.getDialectOp(opMap); Operator &resultOp = tree.getDialectOp(opMap);
auto numOpArgs = resultOp.getNumArgs(); auto numOpArgs = resultOp.getNumArgs();
auto numPatArgs = tree.getNumArgs();
if (numOpArgs != tree.getNumArgs()) { // Get the location for this operation if explicitly provided.
PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " std::string locToUse;
"{1} in pattern vs. {2} in definition", if (numPatArgs != 0) {
resultOp.getOperationName(), tree.getNumArgs(), if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
numOpArgs)); if (lastArg.isLocationDirective())
locToUse = handleUseLocationOf(lastArg);
} }
auto inPattern = numPatArgs - !locToUse.empty();
if (numOpArgs != inPattern) {
PrintFatalError(loc,
formatv("resultant op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
resultOp.getOperationName(), inPattern, numOpArgs));
}
// If no explicit location is given, use the default, all fused, location.
if (locToUse.empty())
locToUse = "odsLoc";
// A map to collect all nested DAG child nodes' names, with operand index as // A map to collect all nested DAG child nodes' names, with operand index as
// the key. This includes both bound and unbound child nodes. // the key. This includes both bound and unbound child nodes.
ChildNodeIndexNameMap childNodeNames; ChildNodeIndexNameMap childNodeNames;
@ -769,9 +812,8 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// create ops for them and remember the symbol names for them, so that we can // create ops for them and remember the symbol names for them, so that we can
// use the results in the current node. This happens in a recursive manner. // use the results in the current node. This happens in a recursive manner.
for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
if (auto child = tree.getArgAsNestedDag(i)) { if (auto child = tree.getArgAsNestedDag(i))
childNodeNames[i] = handleResultPattern(child, i, depth + 1); childNodeNames[i] = handleResultPattern(child, i, depth + 1);
}
} }
// The name of the local variable holding this op. // The name of the local variable holding this op.
@ -811,10 +853,11 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// First prepare local variables for op arguments used in builder call. // First prepare local variables for op arguments used in builder call.
createAggregateLocalVarsForOpArgs(tree, childNodeNames); createAggregateLocalVarsForOpArgs(tree, childNodeNames);
// Then create the op. // Then create the op.
os.indent(6) << formatv( os.indent(6) << formatv(
"{0} = rewriter.create<{1}>(loc, tblgen_values, tblgen_attrs);\n", "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);\n",
valuePackName, resultOp.getQualCppClassName()); valuePackName, resultOp.getQualCppClassName(), locToUse);
os.indent(4) << "}\n"; os.indent(4) << "}\n";
return resultValue; return resultValue;
} }
@ -831,8 +874,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// here given that it's easier for developers to write compared to // here given that it's easier for developers to write compared to
// aggregate-parameter builders. // aggregate-parameter builders.
createSeparateLocalVarsForOpArgs(tree, childNodeNames); createSeparateLocalVarsForOpArgs(tree, childNodeNames);
os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
resultOp.getQualCppClassName()); os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
resultOp.getQualCppClassName(), locToUse);
supplyValuesForOpArgs(tree, childNodeNames); supplyValuesForOpArgs(tree, childNodeNames);
os << "\n );\n"; os << "\n );\n";
os.indent(4) << "}\n"; os.indent(4) << "}\n";
@ -858,9 +902,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
"tblgen_types.push_back(v.getType()); }\n", "tblgen_types.push_back(v.getType()); }\n",
resultIndex + i); resultIndex + i);
} }
os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc, tblgen_types, " os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
"tblgen_values, tblgen_attrs);\n", "tblgen_values, tblgen_attrs);\n",
valuePackName, resultOp.getQualCppClassName()); valuePackName, resultOp.getQualCppClassName(),
locToUse);
os.indent(4) << "}\n"; os.indent(4) << "}\n";
return resultValue; return resultValue;
} }