Skip to content

Commit c65fb32

Browse files
authored
[mlir][vector] Update tests for collapse 3/n (nfc) (llvm#94906)
The main goal of this PR (and subsequent PRs) is to add more tests with scalable vectors to: * vector-transfer-collapse-inner-most-dims.mlir There are quite a few cases to consider, hence this is split into multiple PRs. In this PR, the very first test for `vector.transfer_write` is complemented with all the possible combinations: * scalable (rather than fixed) unit trailing dim, * dynamic (rather than static) trailing dim in the source memref. To this end, the following tests: * `@leading_scalable_dimension_transfer_write` and `@trailing_scalable_one_dim_transfer_write` are replaced with: * `@drop_two_inner_most_dim_scalable_inner_dim` and `@negative_scalable_unit_dim`, respectively. In addition: * "_for_transfer_write" is removed from function names (to reduce noise). In addition, to maintain consistency between the tests for `xfer_read` and `xfer_write`, 2 negative tests for `xfer_read` are also renamed. This is to follow the suggestion made during the review of this PR. Extra comments in "VectorTransforms.cpp" are added to better document the limitations related to scalable vectors and which of the tests added here exercise them. This is a follow-up for: llvm#94490 and llvm#94604. NOTE: This PR is limited to tests for `vector.transfer_write`.
1 parent 242cc20 commit c65fb32

File tree

2 files changed

+78
-44
lines changed

2 files changed

+78
-44
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,11 +1225,19 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
12251225

12261226
/// Returns the number of dims that can be folded away from transfer ops. It returns
12271227
/// a failure if it can not determine the number of dims to be folded.
1228-
/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
1229-
/// `vectorType` is vector<16x16x1x1xf32>. Because there two inner most dims
1230-
/// can be dropped by memref.subview ops.
1231-
/// Example 2: it returns "1" if `srcType` is the same memref type with
1232-
/// [8192, 16, 8, 1] strides.
1228+
///
1229+
/// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
1230+
/// `vectorType` is vector<16x16x1x1xf32>
1231+
/// (the two inner most dims can be dropped by memref.subview ops)
1232+
///
1233+
/// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
1234+
/// [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
1235+
/// (only the inner most unit dim of `srcType` can be dropped)
1236+
///
1237+
/// Ex 3: returns "0" if `srcType` is memref<512x16x1x1xf32> and
1238+
/// `vectorType` is vector<16x16x1x[1]xf32>
1239+
/// (the inner most dim in `vectorType` is not a unit dim (it's a "scalable
1240+
/// unit"))
12331241
static FailureOr<size_t>
12341242
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
12351243
SmallVector<int64_t> srcStrides;
@@ -1351,6 +1359,8 @@ class DropInnerMostUnitDimsTransferRead
13511359
/// vector.transfer_write %0, %subview[%c0, %arg2, %c0]
13521360
/// {in_bounds = [true, true, true]}
13531361
/// : vector<1x16x16xf32>, memref<1x512x16xf32>
1362+
///
1363+
/// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
13541364
class DropInnerMostUnitDimsTransferWrite
13551365
: public OpRewritePattern<vector::TransferWriteOp> {
13561366
using OpRewritePattern::OpRewritePattern;

mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir

Lines changed: 63 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -41,27 +41,27 @@ func.func @contiguous_inner_most_scalable_inner_dim(%in: memref<1x1x8x1xf32, str
4141
// Same as the top example within this split, but the trailing unit dim was
4242
// replaced with a dyn dim - not supported
4343

44-
func.func @non_unit_trailing_dim(%in: memref<1x1x8x?xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{
44+
func.func @negative_dynamic_trailing_dim(%in: memref<1x1x8x?xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{
4545
%c0 = arith.constant 0 : index
4646
%cst = arith.constant 0.0 : f32
4747
%0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x?xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
4848
return %0 : vector<1x8x1xf32>
4949
}
5050

51-
// CHECK-LABEL: func @non_unit_trailing_dim
51+
// CHECK-LABEL: func @negative_dynamic_trailing_dim
5252
// CHECK-NOT: memref.subview
5353
// CHECK-NOT: vector.shape_cast
5454

55-
// Same as the top example within this split, but with a scalable unit dim in
56-
// the output vector - not supported (scalable 1 is _not_ a unit dimension).
55+
// Same as the top example within this split, but with a "scalable unit" dim in
56+
// the output vector - not supported (scalable 1, [1], is _not_ a unit dimension).
5757

58-
func.func @negative_scalable_unit_dim(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x[1]xf32>{
58+
func.func @negative_scalable_one_trailing_dim(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x[1]xf32>{
5959
%c0 = arith.constant 0 : index
6060
%cst = arith.constant 0.0 : f32
6161
%0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x[1]xf32>
6262
return %0 : vector<1x8x[1]xf32>
6363
}
64-
// CHECK-LABEL: func @negative_scalable_unit_dim
64+
// CHECK-LABEL: func @negative_scalable_one_trailing_dim
6565
// CHECK-NOT: memref.subview
6666
// CHECK-NOT: vector.shape_cast
6767

@@ -254,14 +254,14 @@ func.func @negative_non_unit_inner_memref_dim(%arg0: memref<4x8xf32>) -> vector<
254254
// 2. vector.transfer_write
255255
//-----------------------------------------------------------------------------
256256

257-
func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
257+
func.func @drop_two_inner_most_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
258258
%c0 = arith.constant 0 : index
259259
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
260260
{in_bounds = [true, true, true, true, true]}
261261
: vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
262262
return
263263
}
264-
// CHECK: func.func @drop_two_inner_most_dim_for_transfer_write
264+
// CHECK: func.func @drop_two_inner_most_dim
265265
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
266266
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
267267
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -272,16 +272,67 @@ func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1
272272
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
273273
// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
274274

275+
// Same as the top example within this split, but with the inner vector
276+
// dim scalable. Note that this example only makes sense when "16 = [16]" (i.e.
277+
// vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute.
278+
279+
func.func @drop_two_inner_most_dim_scalable_inner_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x[16]x1x1xf32>, %arg2: index) {
280+
%c0 = arith.constant 0 : index
281+
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
282+
{in_bounds = [true, true, true, true, true]}
283+
: vector<1x16x[16]x1x1xf32>, memref<1x512x16x1x1xf32>
284+
return
285+
}
286+
// CHECK: func.func @drop_two_inner_most_dim_scalable_inner_dim
287+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
288+
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
289+
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
290+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
291+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
292+
// CHECK-SAME: memref<1x512x16x1x1xf32> to memref<1x512x16xf32, strided<[8192, 16, 1]>>
293+
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x[16]x1x1xf32> to vector<1x16x[16]xf32>
294+
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
295+
// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
296+
297+
// Same as the top example within this split, but the trailing unit dim was
298+
// replaced with a dyn dim - not supported
299+
300+
func.func @negative_dynamic_trailing_dim(%arg0: memref<1x512x16x1x?xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
301+
%c0 = arith.constant 0 : index
302+
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
303+
{in_bounds = [true, true, true, true, true]}
304+
: vector<1x16x16x1x1xf32>, memref<1x512x16x1x?xf32>
305+
return
306+
}
307+
// CHECK: func.func @negative_dynamic_trailing_dim
308+
// CHECK-NOT: memref.subview
309+
// CHECK-NOT: vector.shape_cast
310+
311+
// Same as the top example within this split, but with a "scalable unit" dim in
312+
// the input vector - not supported (scalable 1, [1], is _not_ a unit dimension).
313+
314+
func.func @negative_scalable_one_trailing_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x[1]xf32>, %arg2: index) {
315+
%c0 = arith.constant 0 : index
316+
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
317+
{in_bounds = [true, true, true, true, true]}
318+
: vector<1x16x16x1x[1]xf32>, memref<1x512x16x1x1xf32>
319+
return
320+
}
321+
322+
// CHECK: func.func @negative_scalable_one_trailing_dim
323+
// CHECK-NOT: memref.subview
324+
// CHECK-NOT: vector.shape_cast
325+
275326
// -----
276327

277-
func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
328+
func.func @drop_inner_most_dim(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
278329
%c0 = arith.constant 0 : index
279330
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
280331
{in_bounds = [true, true, true, true]}
281332
: vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
282333
return
283334
}
284-
// CHECK: func.func @drop_inner_most_dim_for_transfer_write
335+
// CHECK: func.func @drop_inner_most_dim
285336
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
286337
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
287338
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -294,14 +345,14 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32,
294345

295346
// -----
296347

297-
func.func @outer_dyn_drop_inner_most_dim_for_transfer_write(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
348+
func.func @outer_dyn_drop_inner_most_dim(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
298349
%c0 = arith.constant 0 : index
299350
vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0]
300351
{in_bounds = [true, true, true, true]}
301352
: vector<1x16x16x1xf32>, memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
302353
return
303354
}
304-
// CHECK: func.func @outer_dyn_drop_inner_most_dim_for_transfer_write
355+
// CHECK: func.func @outer_dyn_drop_inner_most_dim
305356
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
306357
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
307358
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -325,30 +376,3 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
325376
// The inner most unit dims can not be dropped if the strides are not ones.
326377
// CHECK: func.func @non_unit_strides
327378
// CHECK-NOT: memref.subview
328-
329-
// -----
330-
331-
func.func @leading_scalable_dimension_transfer_write(%dest : memref<24x1xf32>, %vec: vector<[4]x1xf32>) {
332-
%c0 = arith.constant 0 : index
333-
vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x1xf32>, memref<24x1xf32>
334-
return
335-
}
336-
// CHECK: func.func @leading_scalable_dimension_transfer_write
337-
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
338-
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
339-
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
340-
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32>
341-
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>>
342-
343-
// -----
344-
345-
// Negative test: [1] (scalable 1) is _not_ a unit dimension.
346-
func.func @trailing_scalable_one_dim_transfer_write(%dest : memref<24x1xf32>, %vec: vector<4x[1]xf32>, %index: index) {
347-
%c0 = arith.constant 0 : index
348-
vector.transfer_write %vec, %dest[%index, %c0] {in_bounds = [true, true]} : vector<4x[1]xf32>, memref<24x1xf32>
349-
return
350-
}
351-
// CHECK: func.func @trailing_scalable_one_dim_transfer_write
352-
// CHECK-NOT: vector.shape_cast
353-
// CHECK: vector.transfer_write {{.*}} : vector<4x[1]xf32>, memref<24x1xf32>
354-
// CHECK-NOT: vector.shape_cast

0 commit comments

Comments
 (0)