@@ -2756,17 +2756,26 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2756
2756
}
2757
2757
2758
2758
// / Return true if `t1` and `t2` have equal strides (both dynamic or of same
2759
- // / static value).
2760
- static bool haveCompatibleStrides (MemRefType t1, MemRefType t2) {
2759
+ // / static value). Dimensions of `t1` may be dropped in `t2`; these must be
2760
+ // / marked as dropped in `droppedDims`.
2761
+ static bool haveCompatibleStrides (MemRefType t1, MemRefType t2,
2762
+ const llvm::SmallBitVector &droppedDims) {
2763
+ assert (t1.getRank () == droppedDims.size () && " incorrect number of bits" );
2764
+ assert (t1.getRank () - t2.getRank () == droppedDims.count () &&
2765
+ " incorrect number of dropped dims" );
2761
2766
int64_t t1Offset, t2Offset;
2762
2767
SmallVector<int64_t > t1Strides, t2Strides;
2763
2768
auto res1 = getStridesAndOffset (t1, t1Strides, t1Offset);
2764
2769
auto res2 = getStridesAndOffset (t2, t2Strides, t2Offset);
2765
2770
if (failed (res1) || failed (res2))
2766
2771
return false ;
2767
- for (auto [s1, s2] : llvm::zip_equal (t1Strides, t2Strides))
2768
- if (s1 != s2)
2772
+ for (int64_t i = 0 , j = 0 , e = t1.getRank (); i < e; ++i) {
2773
+ if (droppedDims[i])
2774
+ continue ;
2775
+ if (t1Strides[i] != t2Strides[j])
2769
2776
return false ;
2777
+ ++j;
2778
+ }
2770
2779
return true ;
2771
2780
}
2772
2781
@@ -2843,10 +2852,8 @@ LogicalResult SubViewOp::verify() {
2843
2852
return produceSubViewErrorMsg (SliceVerificationResult::LayoutMismatch,
2844
2853
*this , expectedType);
2845
2854
2846
- // Strides must match if there are no rank reductions.
2847
- // TODO: Verify strides when there are rank reductions. Strides are partially
2848
- // checked in `computeMemRefRankReductionMask`.
2849
- if (unusedDims->none () && !haveCompatibleStrides (expectedType, subViewType))
2855
+ // Strides must match.
2856
+ if (!haveCompatibleStrides (expectedType, subViewType, *unusedDims))
2850
2857
return produceSubViewErrorMsg (SliceVerificationResult::LayoutMismatch,
2851
2858
*this , expectedType);
2852
2859
0 commit comments