Skip to content

Commit 1dccdf7

Browse files
[mlir][linalg][transform][python] Add type arg to MatchOp extension.
The extension class to MatchOp has a class method called match_op_names. The previous version of that function did not allow to specify the result type. This, however, may be useful/necessary if the op consuming the resulting handle requires a particular type (such as the bufferization.EmptyTensorToAllocTensorOp). This patch adds an overload to match_op_names that allows to specify the result type. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D155567
1 parent 5bc8364 commit 1dccdf7

File tree

2 files changed

+71
-4
lines changed

2 files changed

+71
-4
lines changed

mlir/python/mlir/dialects/_structured_transform_ops_ext.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,52 @@ def __init__(
8585
class MatchOp:
8686
"""Specialization for MatchOp class."""
8787

88+
@overload
8889
@classmethod
8990
def match_op_names(
90-
MatchOp,
91+
cls,
9192
target: Union[Operation, Value],
9293
names: Sequence[str],
94+
*,
9395
loc=None,
9496
ip=None,
9597
):
96-
pdl_operation_type = pdl.OperationType.get()
97-
return MatchOp(
98-
pdl_operation_type,
98+
...
99+
100+
@overload
101+
@classmethod
102+
def match_op_names(
103+
cls,
104+
result_type: Type,
105+
target: Union[Operation, Value],
106+
names: Sequence[str],
107+
*,
108+
loc=None,
109+
ip=None,
110+
):
111+
...
112+
113+
@classmethod
114+
def match_op_names(
115+
cls,
116+
result_type_or_target: Union[Type, Operation, Value],
117+
target_or_names: Union[Operation, Value, Sequence[str]],
118+
names_or_none: Optional[Sequence[str]] = None,
119+
*,
120+
loc=None,
121+
ip=None,
122+
):
123+
if isinstance(result_type_or_target, Type):
124+
result_type = result_type_or_target
125+
target = target_or_names
126+
names = names_or_none
127+
else:
128+
result_type = transform.AnyOpType.get()
129+
target = result_type_or_target
130+
names = target_or_names
131+
132+
return cls(
133+
result_type,
99134
_get_op_result_or_value(target),
100135
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
101136
loc=loc,

mlir/test/python/dialects/transform_structured_ext.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,38 @@ def testInterchange():
5757
# CHECK: iterator_interchange = [1, 0]
5858

5959

60+
@run
61+
def testMatchOpNames():
62+
sequence = transform.SequenceOp(
63+
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
64+
)
65+
with InsertionPoint(sequence.body):
66+
structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
67+
transform.YieldOp()
68+
# CHECK-LABEL: TEST: testMatchOpNames
69+
# CHECK: transform.structured.match ops
70+
# CHECK-SAME: ["test.dummy"]
71+
# CHECK-SAME: (!transform.any_op) -> !transform.any_op
72+
73+
74+
@run
75+
def testMatchOpNamesTyped():
76+
sequence = transform.SequenceOp(
77+
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
78+
)
79+
with InsertionPoint(sequence.body):
80+
structured.MatchOp.match_op_names(
81+
transform.OperationType.get("test.dummy"),
82+
sequence.bodyTarget,
83+
["test.dummy"],
84+
)
85+
transform.YieldOp()
86+
# CHECK-LABEL: TEST: testMatchOpNamesTyped
87+
# CHECK: transform.structured.match ops
88+
# CHECK-SAME: ["test.dummy"]
89+
# CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">
90+
91+
6092
@run
6193
def testMultitileSizes():
6294
sequence = transform.SequenceOp(

0 commit comments

Comments
 (0)