Skip to content

Commit a99e06a

Browse files
author
MaheshRavishankar
committed
[mlir][Linalg] Avoid generating illegal operations during elementwise fusion.
In some cases, fusion can produce illegal operations if after fusion the range of some of the loops cannot be computed from shapes of its operands. Check for this case and abort the fusion if this happens. Differential Revision: https://reviews.llvm.org/D117602
1 parent e6de53b commit a99e06a

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,13 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
318318
consumer.iterator_types(),
319319
/*doc=*/nullptr,
320320
/*library_call=*/nullptr);
321+
if (!fusedOp.getShapesToLoopsMap()) {
322+
// Fused op has invalid indexing maps. Typically this means something is off
323+
// in the input, but going ahead here would result in verification errors.
324+
// So cleanup and abort.
325+
rewriter.eraseOp(fusedOp);
326+
return llvm::None;
327+
}
321328

322329
// Construct an AffineMap from consumer loops to producer loops.
323330
// consumer loop -> tensor index

mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,3 +945,33 @@ func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -> ten
945945
} -> tensor<?xf32>
946946
return %8 : tensor<?xf32>
947947
}
948+
949+
// -----
950+
951+
func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> {
952+
%c1_i32 = arith.constant 1 : i32
953+
%0 = linalg.generic {
954+
indexing_maps = [affine_map<(d0) -> (d0)>],
955+
iterator_types = ["parallel"]}
956+
outs(%arg0 : tensor<5000xi64>) {
957+
^bb0(%arg3: i64): // no predecessors
958+
%22 = linalg.index 0 : index
959+
%23 = arith.index_cast %22 : index to i64
960+
linalg.yield %23 : i64
961+
} -> tensor<5000xi64>
962+
%1 = linalg.init_tensor [5000] : tensor<5000xi32>
963+
%2 = linalg.generic {
964+
indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>],
965+
iterator_types = ["parallel", "parallel"]}
966+
ins(%0 : tensor<5000xi64>) outs(%1 : tensor<5000xi32>) {
967+
^bb0(%arg3: i64, %arg5: i32): // no predecessors
968+
%22 = arith.index_cast %arg3 : i64 to index
969+
%23 = tensor.extract %arg1[%22] : tensor<5000xi32>
970+
linalg.yield %23 : i32
971+
} -> tensor<5000xi32>
972+
return %2 : tensor<5000xi32>
973+
}
974+
// CHECK-LABEL: func @illegal_fusion(
975+
// CHECK: %[[PRODUCER:.+]] = linalg.generic
976+
// CHECK: linalg.generic
977+
// CHECK-SAME: ins(%[[PRODUCER]]

0 commit comments

Comments
 (0)