Skip to content

Commit 1e3a021

Browse files
committed
[mlir][scf] Update IfOp to have getInvocationBounds
This allows `scf.if` to be used by Control-Flow sink. Depends on D115088 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D115089
1 parent aa53d07 commit 1e3a021

File tree

4 files changed

+84
-3
lines changed

4 files changed

+84
-3
lines changed

mlir/include/mlir/Dialect/SCF/SCFOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,9 @@ def ForOp : SCF_Op<"for",
315315
}
316316

317317
def IfOp : SCF_Op<"if",
318-
[DeclareOpInterfaceMethods<RegionBranchOpInterface>,
318+
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
319+
["getNumRegionInvocations",
320+
"getRegionInvocationBounds"]>,
319321
SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveSideEffects,
320322
NoRegionArguments]> {
321323
let summary = "if-then-else operation";

mlir/include/mlir/Interfaces/ControlFlowInterfaces.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,14 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
140140
of invocations cannot be statically determined, then it will not have a
141141
value (i.e., it is set to `llvm::None`).
142142

143-
`operands` is a set of optional attributes that either correspond to a
144-
constant values for each operand of this operation, or null if that
143+
`operands` is a set of optional attributes that either correspond to
144+
constant values for each operand of this operation or null if that
145145
operand is not a constant.
146+
147+
This method may be called speculatively on operations where the provided
148+
operands are not necessarily the same as the operation's current
149+
operands. This may occur in analyses that wish to determine "what would
150+
be the region invocations if these were the operands?"
146151
}],
147152
"void", "getRegionInvocationBounds",
148153
(ins "::mlir::ArrayRef<::mlir::Attribute>":$operands,

mlir/lib/Dialect/SCF/SCF.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,6 +1188,20 @@ LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
11881188
return success();
11891189
}
11901190

1191+
void IfOp::getRegionInvocationBounds(
1192+
ArrayRef<Attribute> operands,
1193+
SmallVectorImpl<InvocationBounds> &invocationBounds) {
1194+
if (auto cond = operands[0].dyn_cast_or_null<BoolAttr>()) {
1195+
// If the condition is known, then one region is known to be executed once
1196+
// and the other zero times.
1197+
invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
1198+
invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
1199+
} else {
1200+
// Non-constant condition. Each region may be executed 0 or 1 times.
1201+
invocationBounds.assign(2, {0, 1});
1202+
}
1203+
}
1204+
11911205
namespace {
11921206
// Pattern to remove unused IfOp results.
11931207
struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// RUN: mlir-opt -split-input-file -control-flow-sink %s | FileCheck %s
2+
3+
// CHECK-LABEL: @test_scf_if_sink
4+
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
5+
// CHECK: %[[V0:.*]] = scf.if %[[ARG0]]
6+
// CHECK: %[[V1:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
7+
// CHECK: scf.yield %[[V1]]
8+
// CHECK: else
9+
// CHECK: %[[V1:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
10+
// CHECK: scf.yield %[[V1]]
11+
// CHECK: return %[[V0]]
12+
func @test_scf_if_sink(%arg0: i1, %arg1: i32) -> i32 {
13+
%0 = arith.addi %arg1, %arg1 : i32
14+
%1 = arith.muli %arg1, %arg1 : i32
15+
%result = scf.if %arg0 -> i32 {
16+
scf.yield %0 : i32
17+
} else {
18+
scf.yield %1 : i32
19+
}
20+
return %result : i32
21+
}
22+
23+
// -----
24+
25+
func private @consume(i32) -> ()
26+
27+
// CHECK-LABEL: @test_scf_if_then_only_sink
28+
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
29+
// CHECK: scf.if %[[ARG0]]
30+
// CHECK: %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
31+
// CHECK: call @consume(%[[V0]])
32+
func @test_scf_if_then_only_sink(%arg0: i1, %arg1: i32) {
33+
%0 = arith.addi %arg1, %arg1 : i32
34+
scf.if %arg0 {
35+
call @consume(%0) : (i32) -> ()
36+
scf.yield
37+
}
38+
return
39+
}
40+
41+
// -----
42+
43+
func private @consume(i32) -> ()
44+
45+
// CHECK-LABEL: @test_scf_if_double_sink
46+
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
47+
// CHECK: scf.if %[[ARG0]]
48+
// CHECK: scf.if %[[ARG0]]
49+
// CHECK: %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
50+
// CHECK: call @consume(%[[V0]])
51+
func @test_scf_if_double_sink(%arg0: i1, %arg1: i32) {
52+
%0 = arith.addi %arg1, %arg1 : i32
53+
scf.if %arg0 {
54+
scf.if %arg0 {
55+
call @consume(%0) : (i32) -> ()
56+
scf.yield
57+
}
58+
}
59+
return
60+
}

0 commit comments

Comments
 (0)