@@ -378,7 +378,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
378
378
return failure ();
379
379
380
380
// User controlled propagation function.
381
- if (!controlFn (genericOp ))
381
+ if (!controlFn (&packOp. getSourceMutable () ))
382
382
return failure ();
383
383
384
384
// TODO: Enable propagation in the presence of linalg.index and
@@ -488,7 +488,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
488
488
return failure ();
489
489
490
490
// User controlled propagation function.
491
- if (!controlFn (padOp ))
491
+ if (!controlFn (&packOp. getSourceMutable () ))
492
492
return failure ();
493
493
494
494
if (!padOp.getResult ().hasOneUse ())
@@ -844,7 +844,7 @@ class BubbleUpPackOpThroughReshapeOp final
844
844
}
845
845
846
846
// User controlled propagation function.
847
- if (!controlFn (srcOp ))
847
+ if (!controlFn (&packOp. getSourceMutable () ))
848
848
return failure ();
849
849
850
850
return TypeSwitch<Operation *, LogicalResult>(srcOp)
@@ -880,10 +880,13 @@ class BubbleUpPackOpThroughReshapeOp final
880
880
// / %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
881
881
// / inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
882
882
// / : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
883
- static LogicalResult
884
- pushDownUnPackOpThroughExpandShape (tensor::UnPackOp unPackOp,
885
- tensor::ExpandShapeOp expandOp,
886
- PatternRewriter &rewriter) {
883
+ static LogicalResult pushDownUnPackOpThroughExpandShape (
884
+ tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
885
+ PatternRewriter &rewriter, ControlPropagationFn controlFn) {
886
+ // User controlled propagation function.
887
+ if (!controlFn (&expandOp.getSrcMutable ()))
888
+ return failure ();
889
+
887
890
SmallVector<int64_t > innerTileSizes = unPackOp.getStaticTiles ();
888
891
ArrayRef<int64_t > innerDimsPos = unPackOp.getInnerDimsPos ();
889
892
ArrayRef<int64_t > outerDimsPerm = unPackOp.getOuterDimsPerm ();
@@ -970,13 +973,10 @@ class PushDownUnPackOpThroughReshapeOp final
970
973
}
971
974
972
975
Operation *consumerOp = *result.user_begin ();
973
- // User controlled propagation function.
974
- if (!controlFn (consumerOp))
975
- return failure ();
976
-
977
976
return TypeSwitch<Operation *, LogicalResult>(consumerOp)
978
977
.Case ([&](tensor::ExpandShapeOp op) {
979
- return pushDownUnPackOpThroughExpandShape (unPackOp, op, rewriter);
978
+ return pushDownUnPackOpThroughExpandShape (unPackOp, op, rewriter,
979
+ controlFn);
980
980
})
981
981
.Default ([](Operation *) { return failure (); });
982
982
}
@@ -1038,7 +1038,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
1038
1038
// / inner_dims_pos = [3] inner_tiles = [32] into %0
1039
1039
// /
1040
1040
static FailureOr<std::tuple<GenericOp, Value>>
1041
- pushDownUnPackOpThroughGenericOp (RewriterBase &rewriter, GenericOp genericOp) {
1041
+ pushDownUnPackOpThroughGenericOp (RewriterBase &rewriter, GenericOp genericOp,
1042
+ ControlPropagationFn controlFn) {
1042
1043
if (genericOp.getNumResults () != 1 )
1043
1044
return failure ();
1044
1045
@@ -1055,6 +1056,10 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
1055
1056
tensor::UnPackOp producerUnPackOp =
1056
1057
unPackedOperand->get ().getDefiningOp <tensor::UnPackOp>();
1057
1058
assert (producerUnPackOp && " expect a valid UnPackOp" );
1059
+
1060
+ if (!controlFn (unPackedOperand))
1061
+ return failure ();
1062
+
1058
1063
auto packInfo =
1059
1064
getPackingInfoFromOperand (unPackedOperand, genericOp, producerUnPackOp);
1060
1065
if (failed (packInfo))
@@ -1122,10 +1127,8 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
1122
1127
1123
1128
LogicalResult matchAndRewrite (GenericOp genericOp,
1124
1129
PatternRewriter &rewriter) const override {
1125
- if (!controlFn (genericOp))
1126
- return failure ();
1127
-
1128
- auto genericAndRepl = pushDownUnPackOpThroughGenericOp (rewriter, genericOp);
1130
+ auto genericAndRepl =
1131
+ pushDownUnPackOpThroughGenericOp (rewriter, genericOp, controlFn);
1129
1132
if (failed (genericAndRepl))
1130
1133
return failure ();
1131
1134
rewriter.replaceOp (genericOp, std::get<1 >(*genericAndRepl));
@@ -1150,7 +1153,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1150
1153
if (!unpackOp)
1151
1154
return failure ();
1152
1155
1153
- if (!controlFn (padOp))
1156
+ if (!controlFn (& padOp. getSourceMutable () ))
1154
1157
return failure ();
1155
1158
1156
1159
Location loc = padOp.getLoc ();
0 commit comments