Skip to content

Commit e4e0bf6

Browse files
[mlir][Vector] Split transform.vector.lower_mask in 2 ops.
This gives us better control to lower masked operations independently of the create mask operations. It is often useful to maintain high-level mask information instead of lowering it too early to too fine-grained form. Differential Revision: https://reviews.llvm.org/D148162
1 parent e323029 commit e4e0bf6

File tree

3 files changed

+65
-6
lines changed

3 files changed

+65
-6
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,31 @@ def LowerContractionOp : TransformWithPatternsOp<"vector.lower_contraction"> {
122122
}];
123123
}
124124

125-
def LowerMaskOp : TransformWithPatternsOp<"vector.lower_mask"> {
125+
def LowerMasksOp : TransformWithPatternsOp<"vector.lower_masks"> {
126126
let description = [{
127-
Indicates that the vector mask operations nested under the isolated from
128-
above op `target` should be lowered to finer-grained vector primitives.
127+
Indicates that the vector.create_mask and vector.constant_mask operations
128+
nested under the isolated from above op `target` should be lowered to
129+
finer-grained vector primitives.
130+
131+
This is usually a late step that is run after bufferization as part of the
132+
process of lowering to e.g. LLVM or NVVM.
133+
}];
134+
135+
let arguments = (ins TransformHandleTypeInterface:$target);
136+
let results = (outs TransformHandleTypeInterface:$results);
137+
138+
let assemblyFormat = [{
139+
$target
140+
attr-dict
141+
`:` functional-type($target, results)
142+
}];
143+
}
144+
145+
def LowerMaskedTransfersOp : TransformWithPatternsOp<"vector.lower_masked_transfers"> {
146+
let description = [{
147+
Indicates that masked vector.transfer and vector.gather operations nested
148+
under the isolated from above op `target` should be lowered to finer-grained
149+
vector primitives.
129150

130151
This is usually a late step that is run after bufferization as part of the
131152
process of lowering to e.g. LLVM or NVVM.

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,19 @@ void transform::LowerContractionOp::populatePatterns(
6363
}
6464

6565
//===----------------------------------------------------------------------===//
66-
// LowerMaskOp
66+
// LowerMasksOp
6767
//===----------------------------------------------------------------------===//
6868

69-
void transform::LowerMaskOp::populatePatterns(RewritePatternSet &patterns) {
69+
void transform::LowerMasksOp::populatePatterns(RewritePatternSet &patterns) {
7070
populateVectorMaskOpLoweringPatterns(patterns);
71+
}
72+
73+
//===----------------------------------------------------------------------===//
74+
// LowerMaskedTransfersOp
75+
//===----------------------------------------------------------------------===//
76+
77+
void transform::LowerMaskedTransfersOp::populatePatterns(
78+
RewritePatternSet &patterns) {
7179
populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
7280
}
7381

mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,36 @@ transform.sequence failures(propagate) {
9696
%f = transform.structured.match ops{["func.func"]} in %module_op
9797
: (!pdl.operation) -> !pdl.operation
9898

99-
transform.vector.lower_mask %f
99+
transform.vector.lower_masks %f
100+
: (!pdl.operation) -> !pdl.operation
101+
}
102+
103+
// -----
104+
105+
// CHECK-LABEL: func @transfer_read_3d(
106+
func.func @transfer_read_3d(
107+
%t: tensor<?x?x?xf32>, %arg0: index, %arg1: index, %arg2: index)
108+
-> vector<2x1x7xf32> {
109+
%c0 = arith.constant 0 : index
110+
%f0 = arith.constant 0.0 : f32
111+
// CHECK: %[[mask:.*]] = vector.create_mask
112+
// CHECK-NOT: vector.mask
113+
// CHECK: vector.transfer_read {{.*}}, %[[mask]] {in_bounds = [true, true, true]}
114+
// CHECK-SAME: : tensor<?x?x?xf32>, vector<2x1x7xf32>
115+
%0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1>
116+
%1 = vector.mask %0 {
117+
vector.transfer_read %t[%c0, %c0, %c0], %f0 {in_bounds = [true, true, true]}
118+
: tensor<?x?x?xf32>, vector<2x1x7xf32>
119+
} : vector<2x1x7xi1> -> vector<2x1x7xf32>
120+
121+
return %1: vector<2x1x7xf32>
122+
}
123+
124+
transform.sequence failures(propagate) {
125+
^bb1(%module_op: !pdl.operation):
126+
%f = transform.structured.match ops{["func.func"]} in %module_op
127+
: (!pdl.operation) -> !pdl.operation
128+
129+
transform.vector.lower_masked_transfers %f
100130
: (!pdl.operation) -> !pdl.operation
101131
}

0 commit comments

Comments
 (0)