Skip to content

Commit d8a6a69

Browse files
committed
[MLIR][SCF] Place hoisted scf.if->select prior to the remaining if
This patch slightly updates the behavior of scf.if->select to place any hoisted select statements prior to the remaining scf.if body. This allows better composition with other canonicalization passes, such as scf.if nested merging. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D122027
1 parent fc35376 commit d8a6a69

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

mlir/lib/Dialect/SCF/SCF.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,6 +1366,7 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
13661366

13671367
SmallVector<Value> trueYields;
13681368
SmallVector<Value> falseYields;
1369+
rewriter.setInsertionPoint(replacement);
13691370
for (const auto &it :
13701371
llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
13711372
Value trueVal = std::get<0>(it.value());

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,10 +321,10 @@ func @to_select_with_body(%cond: i1) -> index {
321321
// CHECK-LABEL: func @to_select_with_body
322322
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
323323
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
324+
// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
324325
// CHECK: scf.if {{.*}} {
325326
// CHECK: "test.op"() : () -> ()
326327
// CHECK: }
327-
// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
328328
// CHECK: return [[V0]] : index
329329
// -----
330330

@@ -556,10 +556,10 @@ func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
556556
// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
557557
// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> i32
558558
// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
559+
// CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[PRE0]], %[[PRE1]]
559560
// CHECK: scf.if %[[COND]]
560561
// CHECK: "test.run"() : () -> ()
561562
// CHECK: }
562-
// CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[PRE0]], %[[PRE1]]
563563
// CHECK: return %[[RES]]
564564
%0 = "test.op"() : () -> (i32)
565565
%1 = "test.op1"() : () -> (i32)
@@ -933,14 +933,14 @@ func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
933933
return %res#0, %res#1 : i32, i1
934934
}
935935
// CHECK-NEXT: %true = arith.constant true
936+
// CHECK-NEXT: %[[toret:.+]] = arith.xori %arg0, %true : i1
936937
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) {
937938
// CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32
938939
// CHECK-NEXT: scf.yield %[[sv1]] : i32
939940
// CHECK-NEXT: } else {
940941
// CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32
941942
// CHECK-NEXT: scf.yield %[[sv2]] : i32
942943
// CHECK-NEXT: }
943-
// CHECK-NEXT: %[[toret:.+]] = arith.xori %arg0, %true : i1
944944
// CHECK-NEXT: return %[[if]], %[[toret]] : i32, i1
945945

946946
// -----

0 commit comments

Comments
 (0)