Skip to content

Commit 9efdccb

Browse files
[mlir][memref] memref.subview: Verify result strides with rank reductions (#80158)
This is a follow-up on #79865. Result strides are now also verified if the `memref.subview` op has rank reductions.
1 parent 73e5466 commit 9efdccb

File tree

5 files changed

+42
-17
lines changed

5 files changed

+42
-17
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

+15-8
Original file line numberDiff line numberDiff line change
@@ -2756,17 +2756,26 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
27562756
}
27572757

27582758
/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2759-
/// static value).
2760-
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2) {
2759+
/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2760+
/// marked as dropped in `droppedDims`.
2761+
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2762+
const llvm::SmallBitVector &droppedDims) {
2763+
assert(t1.getRank() == droppedDims.size() && "incorrect number of bits");
2764+
assert(t1.getRank() - t2.getRank() == droppedDims.count() &&
2765+
"incorrect number of dropped dims");
27612766
int64_t t1Offset, t2Offset;
27622767
SmallVector<int64_t> t1Strides, t2Strides;
27632768
auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
27642769
auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
27652770
if (failed(res1) || failed(res2))
27662771
return false;
2767-
for (auto [s1, s2] : llvm::zip_equal(t1Strides, t2Strides))
2768-
if (s1 != s2)
2772+
for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2773+
if (droppedDims[i])
2774+
continue;
2775+
if (t1Strides[i] != t2Strides[j])
27692776
return false;
2777+
++j;
2778+
}
27702779
return true;
27712780
}
27722781

@@ -2843,10 +2852,8 @@ LogicalResult SubViewOp::verify() {
28432852
return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
28442853
*this, expectedType);
28452854

2846-
// Strides must match if there are no rank reductions.
2847-
// TODO: Verify strides when there are rank reductions. Strides are partially
2848-
// checked in `computeMemRefRankReductionMask`.
2849-
if (unusedDims->none() && !haveCompatibleStrides(expectedType, subViewType))
2855+
// Strides must match.
2856+
if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
28502857
return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
28512858
*this, expectedType);
28522859

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

+13-4
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,25 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
144144
SmallVector<OpFoldResult> finalStrides;
145145
finalStrides.reserve(subRank);
146146

147+
#ifndef NDEBUG
148+
// Iteration variable for result dimensions of the subview op.
149+
int64_t j = 0;
150+
#endif // NDEBUG
147151
for (unsigned i = 0; i < sourceRank; ++i) {
148152
if (droppedDims.test(i))
149153
continue;
150154

151155
finalSizes.push_back(subSizes[i]);
152156
finalStrides.push_back(strides[i]);
153-
// TODO: Assert that the computed stride matches the respective stride of
154-
// the result type of the subview op (if both are static), once the verifier
155-
// of memref.subview verfies result strides correctly for ops with rank
156-
// reductions.
157+
#ifndef NDEBUG
158+
// Assert that the computed stride matches the stride of the result type of
159+
// the subview op (if both are static).
160+
std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
161+
if (computedStride && !ShapedType::isDynamic(resultStrides[j]))
162+
assert(*computedStride == resultStrides[j] &&
163+
"mismatch between computed stride and result type stride");
164+
++j;
165+
#endif // NDEBUG
157166
}
158167
assert(finalSizes.size() == subRank &&
159168
"Should have populated all the values at this point");

mlir/test/Dialect/MemRef/canonicalize.mlir

+3-3
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@ func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
6262
// -----
6363

6464
func.func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
65-
%arg2 : index) -> memref<?x?xf32, strided<[?, 1], offset: ?>>
65+
%arg2 : index) -> memref<?x?xf32, strided<[?, ?], offset: ?>>
6666
{
6767
%c0 = arith.constant 0 : index
6868
%c1 = arith.constant 1 : index
6969
%c4 = arith.constant 4 : index
70-
%0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
71-
return %0 : memref<?x?xf32, strided<[?, 1], offset: ?>>
70+
%0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
71+
return %0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
7272
}
7373
// CHECK-LABEL: func @rank_reducing_subview_canonicalize
7474
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -613,9 +613,9 @@ func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
613613
{
614614
%0 = memref.subview %m[3, 1, 8] [1, %sz, 1] [1, 1, 1]
615615
: memref<?x?x?xf32>
616-
to memref<?xf32, strided<[1], offset: ?>>
616+
to memref<?xf32, strided<[?], offset: ?>>
617617
%1 = memref.subview %0[6] [1] [1]
618-
: memref<?xf32, strided<[1], offset: ?>>
618+
: memref<?xf32, strided<[?], offset: ?>>
619619
to memref<f32, strided<[], offset: ?>>
620620
return %1 : memref<f32, strided<[], offset: ?>>
621621
}

mlir/test/Dialect/MemRef/invalid.mlir

+9
Original file line numberDiff line numberDiff line change
@@ -1082,3 +1082,12 @@ func.func @subview_invalid_strides(%m: memref<7x22x333x4444xi32>) {
10821082
: memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32>
10831083
return
10841084
}
1085+
1086+
// -----
1087+
1088+
func.func @subview_invalid_strides_rank_reduction(%m: memref<7x22x333x4444xi32>) {
1089+
// expected-error @below{{expected result type to be 'memref<7x11x1x4444xi32, strided<[32556744, 2959704, 4444, 1]>>' or a rank-reduced version. (mismatch of result layout)}}
1090+
%subview = memref.subview %m[0, 0, 0, 0] [7, 11, 1, 4444] [1, 2, 1, 1]
1091+
: memref<7x22x333x4444xi32> to memref<7x11x4444xi32>
1092+
return
1093+
}

0 commit comments

Comments
 (0)