Skip to content

Commit e7bb8dd

Browse files
[mlir][linalg][bufferize] Relax rules for extract_slice/insert_slice matching
The rules were too restrictive, causing out-of-place bufferization when the result of two ExtractSliceOps is fed into an InsertSliceOp. Differential Revision: https://reviews.llvm.org/D111861
1 parent 64591f2 commit e7bb8dd

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -1193,9 +1193,6 @@ bool BufferizationAliasInfo::areEquivalentExtractSliceOps(
11931193
return false;
11941194
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
11951195
return false;
1196-
// TODO: Is the following needed?
1197-
if (!equivalentInfo.isEquivalent(st.result(), sti.source()))
1198-
return false;
11991196
return true;
12001197
}
12011198

mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir

+26
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,32 @@ func @extract_slice_to_linalg_write_use(
278278
return %D, %E: tensor<4x4xf32>, tensor<4x4xf32>
279279
}
280280

281+
// -----
282+
283+
// CHECK-LABEL: func @insert_slice_double_extract_slice
284+
func @insert_slice_double_extract_slice(
285+
%s1: index, %s2: index, %s3: index, %s4: index, %A: tensor<8x6xf32>,
286+
%B: tensor<6x6xf32>, %C: tensor<30x20xf32> {linalg.inplaceable = true})
287+
-> tensor<30x20xf32> {
288+
// CHECK: tensor.extract_slice
289+
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
290+
%15 = tensor.extract_slice %C[%s3, %s4] [%s1, %s2] [1, 1] : tensor<30x20xf32> to tensor<?x?xf32>
291+
292+
// CHECK: linalg.matmul
293+
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
294+
%18 = linalg.matmul ins(%A, %B : tensor<8x6xf32>, tensor<6x6xf32>) outs(%15 : tensor<?x?xf32>) -> tensor<?x?xf32>
295+
296+
// CHECK: tensor.extract_slice
297+
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
298+
%19 = tensor.extract_slice %18[0, 0] [%s1, %s2] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
299+
300+
// CHECK: tensor.insert_slice
301+
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
302+
%20 = tensor.insert_slice %19 into %C[%s3, %s4] [%s1, %s2] [1, 1] : tensor<?x?xf32> into tensor<30x20xf32>
303+
304+
return %20 : tensor<30x20xf32>
305+
}
306+
281307
//===----------------------------------------------------------------------===//
282308
// Transitive cases
283309
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)