Skip to content

Commit 29429d1

Browse files
committed
[drr] Add $_loc special directive for NativeCodeCall
Allows propagating the location to ops created via NativeCodeCall. Differential Revision: https://reviews.llvm.org/D85704
1 parent 4a646ca commit 29429d1

File tree

4 files changed

+36
-19
lines changed

4 files changed

+36
-19
lines changed

mlir/docs/DeclarativeRewrites.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,10 +384,12 @@ In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N`. The former
384384
is called _special placeholder_, while the latter is called _positional
385385
placeholder_.
386386

387-
`NativeCodeCall` right now only supports two special placeholders: `$_builder`
388-
and `$_self`:
387+
`NativeCodeCall` right now only supports three special placeholders:
388+
`$_builder`, `$_loc`, and `$_self`:
389389

390390
* `$_builder` will be replaced by the current `mlir::PatternRewriter`.
391+
* `$_loc` will be replaced by the fused location or custom location (as
392+
determined by location directive).
391393
* `$_self` will be replaced with the entity `NativeCodeCall` is attached to.
392394

393395
We have seen how `$_builder` can be used in the above; it allows us to pass a

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,8 @@ def OpNativeCodeCall3 : TEST_Op<"native_code_call3"> {
724724
// Test that NativeCodeCall is not ignored if it is not used to directly
725725
// replace the matched root op.
726726
def : Pattern<(OpNativeCodeCall3 $input),
727-
[(NativeCodeCall<"createOpI($_builder, $0)"> $input), (OpK)]>;
727+
[(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input),
728+
(OpK)]>;
728729

729730
// Test AllAttrConstraintsOf.
730731
def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
2121
return choice.getValue() ? input1 : input2;
2222
}
2323

24-
static void createOpI(PatternRewriter &rewriter, Value input) {
25-
rewriter.create<OpI>(rewriter.getUnknownLoc(), input);
24+
static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
25+
rewriter.create<OpI>(loc, input);
2626
}
2727

2828
static void handleNoResultOp(PatternRewriter &rewriter,

mlir/tools/mlir-tblgen/RewriterGen.cpp

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ class PatternEmitter {
112112
// Returns the symbol of the old value serving as the replacement.
113113
StringRef handleReplaceWithValue(DagNode tree);
114114

115+
// Returns the location value to use.
116+
std::pair<bool, std::string> getLocation(DagNode tree);
117+
115118
// Returns the location value to use.
116119
std::string handleLocationDirective(DagNode tree);
117120

@@ -779,13 +782,18 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
779782
PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
780783
Twine(tree.getNumArgs()));
781784
}
782-
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
785+
bool hasLocationDirective;
786+
std::string locToUse;
787+
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
788+
789+
for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
783790
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
784791
LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
785792
<< " replacement: " << attrs[i] << "\n");
786793
}
787-
return std::string(tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3],
788-
attrs[4], attrs[5], attrs[6], attrs[7]));
794+
return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
795+
attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
796+
attrs[6], attrs[7]));
789797
}
790798

791799
int PatternEmitter::getNodeValueCount(DagNode node) {
@@ -804,6 +812,20 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
804812
return 1;
805813
}
806814

815+
std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) {
816+
auto numPatArgs = tree.getNumArgs();
817+
818+
if (numPatArgs != 0) {
819+
if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
820+
if (lastArg.isLocationDirective()) {
821+
return std::make_pair(true, handleLocationDirective(lastArg));
822+
}
823+
}
824+
825+
// If no explicit location is given, use the default, all fused, location.
826+
return std::make_pair(false, "odsLoc");
827+
}
828+
807829
std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
808830
int depth) {
809831
LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
@@ -814,26 +836,18 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
814836
auto numOpArgs = resultOp.getNumArgs();
815837
auto numPatArgs = tree.getNumArgs();
816838

817-
// Get the location for this operation if explicitly provided.
839+
bool hasLocationDirective;
818840
std::string locToUse;
819-
if (numPatArgs != 0) {
820-
if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
821-
if (lastArg.isLocationDirective())
822-
locToUse = handleLocationDirective(lastArg);
823-
}
841+
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
824842

825-
auto inPattern = numPatArgs - !locToUse.empty();
843+
auto inPattern = numPatArgs - hasLocationDirective;
826844
if (numOpArgs != inPattern) {
827845
PrintFatalError(loc,
828846
formatv("resultant op '{0}' argument number mismatch: "
829847
"{1} in pattern vs. {2} in definition",
830848
resultOp.getOperationName(), inPattern, numOpArgs));
831849
}
832850

833-
// If no explicit location is given, use the default, all fused, location.
834-
if (locToUse.empty())
835-
locToUse = "odsLoc";
836-
837851
// A map to collect all nested DAG child nodes' names, with operand index as
838852
// the key. This includes both bound and unbound child nodes.
839853
ChildNodeIndexNameMap childNodeNames;

0 commit comments

Comments
 (0)