Skip to content

Commit 04fc471

Browse files
authored
[mlir][linalg] Switch to use OpOperand* in ControlPropagationFn. (llvm#96697)
It's not easy to determine whether we want to propagate pack/unpack ops because we don't know the (producer, consumer) information. The revisions switch it to `OpOperand*`, so the control function can capture the (producer, consumer) pair. E.g., ``` Operation *producer = opOperand->get().getDefiningOp(); Operation *consumer = opOperand->getOwner(); ```
1 parent 915372a commit 04fc471

File tree

3 files changed

+23
-20
lines changed

3 files changed

+23
-20
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1652,7 +1652,7 @@ void populateElementwiseOpsFusionPatterns(
16521652

16531653
/// Function type which is used to control propagation of tensor.pack/unpack
16541654
/// ops.
1655-
using ControlPropagationFn = std::function<bool(Operation *op)>;
1655+
using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;
16561656

16571657
/// Patterns to bubble up or down data layout ops across other operations.
16581658
void populateDataLayoutPropagationPatterns(

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
378378
return failure();
379379

380380
// User controlled propagation function.
381-
if (!controlFn(genericOp))
381+
if (!controlFn(&packOp.getSourceMutable()))
382382
return failure();
383383

384384
// TODO: Enable propagation in the presence of linalg.index and
@@ -488,7 +488,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
488488
return failure();
489489

490490
// User controlled propagation function.
491-
if (!controlFn(padOp))
491+
if (!controlFn(&packOp.getSourceMutable()))
492492
return failure();
493493

494494
if (!padOp.getResult().hasOneUse())
@@ -844,7 +844,7 @@ class BubbleUpPackOpThroughReshapeOp final
844844
}
845845

846846
// User controlled propagation function.
847-
if (!controlFn(srcOp))
847+
if (!controlFn(&packOp.getSourceMutable()))
848848
return failure();
849849

850850
return TypeSwitch<Operation *, LogicalResult>(srcOp)
@@ -880,10 +880,13 @@ class BubbleUpPackOpThroughReshapeOp final
880880
/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
881881
/// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
882882
/// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
883-
static LogicalResult
884-
pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
885-
tensor::ExpandShapeOp expandOp,
886-
PatternRewriter &rewriter) {
883+
static LogicalResult pushDownUnPackOpThroughExpandShape(
884+
tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
885+
PatternRewriter &rewriter, ControlPropagationFn controlFn) {
886+
// User controlled propagation function.
887+
if (!controlFn(&expandOp.getSrcMutable()))
888+
return failure();
889+
887890
SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
888891
ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
889892
ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
@@ -970,13 +973,10 @@ class PushDownUnPackOpThroughReshapeOp final
970973
}
971974

972975
Operation *consumerOp = *result.user_begin();
973-
// User controlled propagation function.
974-
if (!controlFn(consumerOp))
975-
return failure();
976-
977976
return TypeSwitch<Operation *, LogicalResult>(consumerOp)
978977
.Case([&](tensor::ExpandShapeOp op) {
979-
return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
978+
return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
979+
controlFn);
980980
})
981981
.Default([](Operation *) { return failure(); });
982982
}
@@ -1038,7 +1038,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
10381038
/// inner_dims_pos = [3] inner_tiles = [32] into %0
10391039
///
10401040
static FailureOr<std::tuple<GenericOp, Value>>
1041-
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
1041+
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
1042+
ControlPropagationFn controlFn) {
10421043
if (genericOp.getNumResults() != 1)
10431044
return failure();
10441045

@@ -1055,6 +1056,10 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
10551056
tensor::UnPackOp producerUnPackOp =
10561057
unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
10571058
assert(producerUnPackOp && "expect a valid UnPackOp");
1059+
1060+
if (!controlFn(unPackedOperand))
1061+
return failure();
1062+
10581063
auto packInfo =
10591064
getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
10601065
if (failed(packInfo))
@@ -1122,10 +1127,8 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
11221127

11231128
LogicalResult matchAndRewrite(GenericOp genericOp,
11241129
PatternRewriter &rewriter) const override {
1125-
if (!controlFn(genericOp))
1126-
return failure();
1127-
1128-
auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
1130+
auto genericAndRepl =
1131+
pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
11291132
if (failed(genericAndRepl))
11301133
return failure();
11311134
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
@@ -1150,7 +1153,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
11501153
if (!unpackOp)
11511154
return failure();
11521155

1153-
if (!controlFn(padOp))
1156+
if (!controlFn(&padOp.getSourceMutable()))
11541157
return failure();
11551158

11561159
Location loc = padOp.getLoc();

mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct TestDataLayoutPropagationPass
3333
MLIRContext *context = &getContext();
3434
RewritePatternSet patterns(context);
3535
linalg::populateDataLayoutPropagationPatterns(
36-
patterns, [](Operation *op) { return true; });
36+
patterns, [](OpOperand *opOperand) { return true; });
3737
if (failed(
3838
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
3939
return signalPassFailure();

0 commit comments

Comments
 (0)