Skip to content

Commit 26722f5

Browse files
authored
[MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization (llvm#84225)
The current canonicalization of `memref.dim` operating on the result of `memref.reshape` into `memref.load` is incorrect as it doesn't check whether the `index` operand of `memref.dim` dominates the source `memref.reshape` op. It always introduces `memref.load` right after `memref.reshape` to ensure the `memref` is not mutated before the `memref.load` call. As a result, the following error is observed: ``` $> mlir-opt --canonicalize input.mlir func.func @reshape_dim(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index { %c4 = arith.constant 4 : index %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32> %0 = arith.muli %arg2, %c4 : index %dim = memref.dim %reshape, %0 : memref<*xf32> return %dim : index } ``` results in: ``` dominator.mlir:22:12: error: operand #1 does not dominate this use %dim = memref.dim %reshape, %0 : memref<*xf32> ^ dominator.mlir:22:12: note: see current operation: %1 = "memref.load"(%arg1, %2) <{nontemporal = false}> : (memref<?xindex>, index) -> index dominator.mlir:21:10: note: operand defined here (op in the same block) %0 = arith.muli %arg2, %c4 : index ``` Properly fixing this issue requires a dominator analysis which is expensive to run within a canonicalization pattern. So, this patch fixes the canonicalization pattern by being more strict/conservative about the legality condition in which we perform this canonicalization. The more general pattern is also added to `tensor.dim`. Since tensors are immutable we don't need to worry about where to introduce the `tensor.extract` call after canonicalization.
1 parent 2a30684 commit 26722f5

File tree

4 files changed

+191
-2
lines changed

4 files changed

+191
-2
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

+31-1
Original file line numberDiff line numberDiff line change
@@ -1080,7 +1080,37 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
10801080
auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
10811081

10821082
if (!reshape)
1083-
return failure();
1083+
return rewriter.notifyMatchFailure(
1084+
dim, "Dim op is not defined by a reshape op.");
1085+
1086+
// dim of a memref reshape can be folded if dim.getIndex() dominates the
1087+
// reshape. Instead of using `DominanceInfo` (which is usually costly) we
1088+
// cheaply check that either of the following conditions hold:
1089+
// 1. dim.getIndex() is defined in the same block as reshape but before
1090+
// reshape.
1091+
// 2. dim.getIndex() is defined in a parent block of
1092+
// reshape.
1093+
1094+
// Check condition 1
1095+
if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1096+
if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1097+
if (reshape->isBeforeInBlock(definingOp)) {
1098+
return rewriter.notifyMatchFailure(
1099+
dim,
1100+
"dim.getIndex is not defined before reshape in the same block.");
1101+
}
1102+
} // else dim.getIndex is a block argument to reshape->getBlock and
1103+
// dominates reshape
1104+
} // Check condition 2
1105+
else if (dim->getBlock() != reshape->getBlock() &&
1106+
!dim.getIndex().getParentRegion()->isProperAncestor(
1107+
reshape->getParentRegion())) {
1108+
// If dim and reshape are in the same block but dim.getIndex() isn't, we
1109+
// already know dim.getIndex() dominates reshape without calling
1110+
// `isProperAncestor`
1111+
return rewriter.notifyMatchFailure(
1112+
dim, "dim.getIndex does not dominate reshape.");
1113+
}
10841114

10851115
// Place the load directly after the reshape to ensure that the shape memref
10861116
// was not mutated.

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

+27-1
Original file line numberDiff line numberDiff line change
@@ -824,11 +824,37 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
824824
return success();
825825
}
826826
};
827+
828+
/// Fold dim of a tensor reshape operation to a extract into the reshape's shape
829+
/// operand.
830+
struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
831+
using OpRewritePattern<DimOp>::OpRewritePattern;
832+
833+
LogicalResult matchAndRewrite(DimOp dim,
834+
PatternRewriter &rewriter) const override {
835+
auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
836+
837+
if (!reshape)
838+
return failure();
839+
840+
// Since tensors are immutable we don't need to worry about where to place
841+
// the extract call
842+
rewriter.setInsertionPointAfter(dim);
843+
Location loc = dim.getLoc();
844+
Value extract =
845+
rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
846+
if (extract.getType() != dim.getType())
847+
extract =
848+
rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
849+
rewriter.replaceOp(dim, extract);
850+
return success();
851+
}
852+
};
827853
} // namespace
828854

829855
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
830856
MLIRContext *context) {
831-
results.add<DimOfCastOp, DimOfDestStyleOp>(context);
857+
results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
832858
}
833859

834860
//===----------------------------------------------------------------------===//

mlir/test/Dialect/MemRef/canonicalize.mlir

+53
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,59 @@ func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
313313

314314
// -----
315315

316+
// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
317+
// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
318+
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
319+
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>,
320+
// CHECK-SAME: %[[IDX:[0-9a-z]+]]: index
321+
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
322+
// CHECK-NOT: memref.dim
323+
// CHECK: return %[[DIM]] : index
324+
func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
325+
%reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
326+
%dim = memref.dim %reshape, %arg2 : memref<*xf32>
327+
return %dim : index
328+
}
329+
330+
// -----
331+
332+
// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
333+
// CHECK-LABEL: func @dim_of_memref_reshape_for(
334+
// CHECK: memref.reshape
335+
// CHECK: memref.dim
336+
// CHECK-NOT: memref.load
337+
func.func @dim_of_memref_reshape_for( %arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
338+
%c0 = arith.constant 0 : index
339+
%c1 = arith.constant 1 : index
340+
%c4 = arith.constant 4 : index
341+
342+
%0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
343+
344+
%1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
345+
%2 = memref.dim %0, %arg2 : memref<*xf32>
346+
%3 = arith.muli %arg3, %2 : index
347+
scf.yield %3 : index
348+
}
349+
return %1 : index
350+
}
351+
352+
// -----
353+
354+
// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
355+
// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
356+
// CHECK: memref.reshape
357+
// CHECK: memref.dim
358+
// CHECK-NOT: memref.load
359+
func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
360+
%c4 = arith.constant 4 : index
361+
%reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
362+
%0 = arith.muli %arg2, %c4 : index
363+
%dim = memref.dim %reshape, %0 : memref<*xf32>
364+
return %dim : index
365+
}
366+
367+
// -----
368+
316369
// CHECK-LABEL: func @alloc_const_fold
317370
func.func @alloc_const_fold() -> memref<?xf32> {
318371
// CHECK-NEXT: memref.alloc() : memref<4xf32>

mlir/test/Dialect/Tensor/canonicalize.mlir

+80
Original file line numberDiff line numberDiff line change
@@ -2287,3 +2287,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
22872287
// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
22882288
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
22892289
// CHECK: return %[[SRC]]
2290+
2291+
// -----
2292+
2293+
// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
2294+
// CHECK-LABEL: func @dim_of_reshape(
2295+
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
2296+
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
2297+
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
2298+
// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
2299+
// CHECK-NOT: tensor.store
2300+
// CHECK-NOT: tensor.dim
2301+
// CHECK-NOT: tensor.reshape
2302+
// CHECK: return %[[DIM]] : index
2303+
func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
2304+
-> index {
2305+
%c3 = arith.constant 3 : index
2306+
%0 = tensor.reshape %arg0(%arg1)
2307+
: (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
2308+
// Update the shape to test that the load ends up in the right place.
2309+
tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
2310+
%1 = tensor.dim %0, %c3 : tensor<*xf32>
2311+
return %1 : index
2312+
}
2313+
2314+
// -----
2315+
2316+
// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
2317+
// CHECK-LABEL: func @dim_of_reshape_i32(
2318+
// CHECK: tensor.extract
2319+
// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
2320+
// CHECK-NOT: tensor.dim
2321+
// CHECK-NOT: tensor.reshape
2322+
// CHECK: return %[[CAST]] : index
2323+
func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
2324+
-> index {
2325+
%c3 = arith.constant 3 : index
2326+
%0 = tensor.reshape %arg0(%arg1)
2327+
: (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
2328+
%1 = tensor.dim %0, %c3 : tensor<*xf32>
2329+
return %1 : index
2330+
}
2331+
2332+
// -----
2333+
2334+
// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
2335+
// CHECK-LABEL: func @dim_of_reshape_for(
2336+
// CHECK: scf.for
2337+
// CHECK-NEXT: tensor.extract
2338+
// CHECK-NOT: tensor.dim
2339+
// CHECK-NOT: tensor.reshape
2340+
func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
2341+
%c0 = arith.constant 0 : index
2342+
%c1 = arith.constant 1 : index
2343+
%c4 = arith.constant 4 : index
2344+
2345+
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
2346+
2347+
%1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
2348+
%2 = tensor.dim %0, %arg2 : tensor<*xf32>
2349+
%3 = arith.muli %arg3, %2 : index
2350+
scf.yield %3 : index
2351+
}
2352+
return %1 : index
2353+
}
2354+
2355+
// -----
2356+
2357+
// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
2358+
// CHECK-LABEL: func @dim_of_reshape_undominated(
2359+
// CHECK: arith.muli
2360+
// CHECK-NEXT: tensor.extract
2361+
// CHECK-NOT: tensor.dim
2362+
// CHECK-NOT: tensor.reshape
2363+
func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
2364+
%c4 = arith.constant 4 : index
2365+
%reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
2366+
%0 = arith.muli %arg2, %c4 : index
2367+
%dim = tensor.dim %reshape, %0 : tensor<*xf32>
2368+
return %dim : index
2369+
}

0 commit comments

Comments
 (0)