@@ -41,27 +41,27 @@ func.func @contiguous_inner_most_scalable_inner_dim(%in: memref<1x1x8x1xf32, str
41
41
// Same as the top example within this split, but the trailing unit dim was
42
42
// replaced with a dyn dim - not supported
43
43
44
- func.func @non_unit_trailing_dim (%in: memref <1 x1 x8 x?xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x8 x1 xf32 >{
44
+ func.func @negative_dynamic_trailing_dim (%in: memref <1 x1 x8 x?xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x8 x1 xf32 >{
45
45
%c0 = arith.constant 0 : index
46
46
%cst = arith.constant 0.0 : f32
47
47
%0 = vector.transfer_read %in [%c0 , %c0 , %c0 , %c0 ], %cst {in_bounds = [true , true , true ]} : memref <1 x1 x8 x?xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>, vector <1 x8 x1 xf32 >
48
48
return %0 : vector <1 x8 x1 xf32 >
49
49
}
50
50
51
- // CHECK-LABEL: func @non_unit_trailing_dim
51
+ // CHECK-LABEL: func @negative_dynamic_trailing_dim
52
52
// CHECK-NOT: memref.subview
53
53
// CHECK-NOT: vector.shape_cast
54
54
55
- // Same as the top example within this split, but with a scalable unit dim in
56
- // the output vector - not supported (scalable 1 is _not_ a unit dimension).
55
+ // Same as the top example within this split, but with a " scalable unit" dim in
56
+ // the output vector - not supported (scalable 1, [1], is _not_ a unit dimension).
57
57
58
- func.func @negative_scalable_unit_dim (%in: memref <1 x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x8 x[1 ]xf32 >{
58
+ func.func @negative_scalable_one_trailing_dim (%in: memref <1 x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x8 x[1 ]xf32 >{
59
59
%c0 = arith.constant 0 : index
60
60
%cst = arith.constant 0.0 : f32
61
61
%0 = vector.transfer_read %in [%c0 , %c0 , %c0 , %c0 ], %cst {in_bounds = [true , true , true ]} : memref <1 x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>, vector <1 x8 x[1 ]xf32 >
62
62
return %0 : vector <1 x8 x[1 ]xf32 >
63
63
}
64
- // CHECK-LABEL: func @negative_scalable_unit_dim
64
+ // CHECK-LABEL: func @negative_scalable_one_trailing_dim
65
65
// CHECK-NOT: memref.subview
66
66
// CHECK-NOT: vector.shape_cast
67
67
@@ -254,14 +254,14 @@ func.func @negative_non_unit_inner_memref_dim(%arg0: memref<4x8xf32>) -> vector<
254
254
// 2. vector.transfer_write
255
255
//-----------------------------------------------------------------------------
256
256
257
- func.func @drop_two_inner_most_dim_for_transfer_write (%arg0: memref <1 x512 x16 x1 x1 xf32 >, %arg1: vector <1 x16 x16 x1 x1 xf32 >, %arg2: index ) {
257
+ func.func @drop_two_inner_most_dim (%arg0: memref <1 x512 x16 x1 x1 xf32 >, %arg1: vector <1 x16 x16 x1 x1 xf32 >, %arg2: index ) {
258
258
%c0 = arith.constant 0 : index
259
259
vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 , %c0 ]
260
260
{in_bounds = [true , true , true , true , true ]}
261
261
: vector <1 x16 x16 x1 x1 xf32 >, memref <1 x512 x16 x1 x1 xf32 >
262
262
return
263
263
}
264
- // CHECK: func.func @drop_two_inner_most_dim_for_transfer_write
264
+ // CHECK: func.func @drop_two_inner_most_dim
265
265
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
266
266
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
267
267
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -272,16 +272,67 @@ func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1
272
272
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
273
273
// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
274
274
275
+ // Same as the top example within this split, but with the inner vector
276
+ // dim scalable. Note that this example only makes sense when "16 = [16]" (i.e.
277
+ // vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute.
278
+
279
+ func.func @drop_two_inner_most_dim_scalable_inner_dim (%arg0: memref <1 x512 x16 x1 x1 xf32 >, %arg1: vector <1 x16 x[16 ]x1 x1 xf32 >, %arg2: index ) {
280
+ %c0 = arith.constant 0 : index
281
+ vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 , %c0 ]
282
+ {in_bounds = [true , true , true , true , true ]}
283
+ : vector <1 x16 x[16 ]x1 x1 xf32 >, memref <1 x512 x16 x1 x1 xf32 >
284
+ return
285
+ }
286
+ // CHECK: func.func @drop_two_inner_most_dim_scalable_inner_dim
287
+ // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
288
+ // CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
289
+ // CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
290
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
291
+ // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
292
+ // CHECK-SAME: memref<1x512x16x1x1xf32> to memref<1x512x16xf32, strided<[8192, 16, 1]>>
293
+ // CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x[16]x1x1xf32> to vector<1x16x[16]xf32>
294
+ // CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
295
+ // CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
296
+
297
+ // Same as the top example within this split, but the trailing unit dim was
298
+ // replaced with a dyn dim - not supported
299
+
300
+ func.func @negative_dynamic_trailing_dim (%arg0: memref <1 x512 x16 x1 x?xf32 >, %arg1: vector <1 x16 x16 x1 x1 xf32 >, %arg2: index ) {
301
+ %c0 = arith.constant 0 : index
302
+ vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 , %c0 ]
303
+ {in_bounds = [true , true , true , true , true ]}
304
+ : vector <1 x16 x16 x1 x1 xf32 >, memref <1 x512 x16 x1 x?xf32 >
305
+ return
306
+ }
307
+ // CHECK: func.func @negative_dynamic_trailing_dim
308
+ // CHECK-NOT: memref.subview
309
+ // CHECK-NOT: vector.shape_cast
310
+
311
+ // Same as the top example within this split, but with a "scalable unit" dim in
312
+ // the input vector - not supported (scalable 1, [1], is _not_ a unit dimension).
313
+
314
+ func.func @negative_scalable_one_trailing_dim (%arg0: memref <1 x512 x16 x1 x1 xf32 >, %arg1: vector <1 x16 x16 x1 x[1 ]xf32 >, %arg2: index ) {
315
+ %c0 = arith.constant 0 : index
316
+ vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 , %c0 ]
317
+ {in_bounds = [true , true , true , true , true ]}
318
+ : vector <1 x16 x16 x1 x[1 ]xf32 >, memref <1 x512 x16 x1 x1 xf32 >
319
+ return
320
+ }
321
+
322
+ // CHECK: func.func @negative_scalable_one_trailing_dim
323
+ // CHECK-NOT: memref.subview
324
+ // CHECK-NOT: vector.shape_cast
325
+
275
326
// -----
276
327
277
- func.func @drop_inner_most_dim_for_transfer_write (%arg0: memref <1 x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>, %arg1: vector <1 x16 x16 x1 xf32 >, %arg2: index ) {
328
+ func.func @drop_inner_most_dim (%arg0: memref <1 x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>, %arg1: vector <1 x16 x16 x1 xf32 >, %arg2: index ) {
278
329
%c0 = arith.constant 0 : index
279
330
vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 ]
280
331
{in_bounds = [true , true , true , true ]}
281
332
: vector <1 x16 x16 x1 xf32 >, memref <1 x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>
282
333
return
283
334
}
284
- // CHECK: func.func @drop_inner_most_dim_for_transfer_write
335
+ // CHECK: func.func @drop_inner_most_dim
285
336
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
286
337
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
287
338
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -294,14 +345,14 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32,
294
345
295
346
// -----
296
347
297
- func.func @outer_dyn_drop_inner_most_dim_for_transfer_write (%arg0: memref <?x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>, %arg1: vector <1 x16 x16 x1 xf32 >, %arg2: index ) {
348
+ func.func @outer_dyn_drop_inner_most_dim (%arg0: memref <?x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>, %arg1: vector <1 x16 x16 x1 xf32 >, %arg2: index ) {
298
349
%c0 = arith.constant 0 : index
299
350
vector.transfer_write %arg1 , %arg0 [%arg2 , %c0 , %c0 , %c0 ]
300
351
{in_bounds = [true , true , true , true ]}
301
352
: vector <1 x16 x16 x1 xf32 >, memref <?x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>
302
353
return
303
354
}
304
- // CHECK: func.func @outer_dyn_drop_inner_most_dim_for_transfer_write
355
+ // CHECK: func.func @outer_dyn_drop_inner_most_dim
305
356
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
306
357
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
307
358
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -325,30 +376,3 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
325
376
// The inner most unit dims can not be dropped if the strides are not ones.
326
377
// CHECK: func.func @non_unit_strides
327
378
// CHECK-NOT: memref.subview
328
-
329
- // -----
330
-
331
- func.func @leading_scalable_dimension_transfer_write (%dest : memref <24 x1 xf32 >, %vec: vector <[4 ]x1 xf32 >) {
332
- %c0 = arith.constant 0 : index
333
- vector.transfer_write %vec , %dest [%c0 , %c0 ] {in_bounds = [true , true ]} : vector <[4 ]x1 xf32 >, memref <24 x1 xf32 >
334
- return
335
- }
336
- // CHECK: func.func @leading_scalable_dimension_transfer_write
337
- // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
338
- // CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
339
- // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
340
- // CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32>
341
- // CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>>
342
-
343
- // -----
344
-
345
- // Negative test: [1] (scalable 1) is _not_ a unit dimension.
346
- func.func @trailing_scalable_one_dim_transfer_write (%dest : memref <24 x1 xf32 >, %vec: vector <4 x[1 ]xf32 >, %index: index ) {
347
- %c0 = arith.constant 0 : index
348
- vector.transfer_write %vec , %dest [%index , %c0 ] {in_bounds = [true , true ]} : vector <4 x[1 ]xf32 >, memref <24 x1 xf32 >
349
- return
350
- }
351
- // CHECK: func.func @trailing_scalable_one_dim_transfer_write
352
- // CHECK-NOT: vector.shape_cast
353
- // CHECK: vector.transfer_write {{.*}} : vector<4x[1]xf32>, memref<24x1xf32>
354
- // CHECK-NOT: vector.shape_cast
0 commit comments