Skip to content

Commit e60dc8e

Browse files
committed
[flang][hlfir] Expand hlfir.assign's with scalar RHS.
Expanding hlfir.assign operations with a scalar RHS late in the MLIR optimization pipeline allows LLVM to recognize most of them as simple memset loops. This is especially important for small LHS arrays, because the assign loop nest may be completely unrolled, enabling more value propagation. Reviewed By: tblah Differential Revision: https://reviews.llvm.org/D159151
1 parent 282da83 commit e60dc8e

File tree

2 files changed

+187
-0
lines changed

2 files changed

+187
-0
lines changed

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,65 @@ mlir::LogicalResult ElementalAssignBufferization::matchAndRewrite(
358358
return mlir::success();
359359
}
360360

361+
/// Expand hlfir.assign of a scalar RHS to array LHS into a loop nest
362+
/// of element-by-element assignments:
363+
/// hlfir.assign %cst to %0 : f32, !fir.ref<!fir.array<6x6xf32>>
364+
/// into:
365+
/// fir.do_loop %arg0 = %c1 to %c6 step %c1 unordered {
366+
/// fir.do_loop %arg1 = %c1 to %c6 step %c1 unordered {
367+
/// %1 = hlfir.designate %0 (%arg1, %arg0) :
368+
/// (!fir.ref<!fir.array<6x6xf32>>, index, index) -> !fir.ref<f32>
369+
/// hlfir.assign %cst to %1 : f32, !fir.ref<f32>
370+
/// }
371+
/// }
372+
class BroadcastAssignBufferization
373+
: public mlir::OpRewritePattern<hlfir::AssignOp> {
374+
private:
375+
public:
376+
using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern;
377+
378+
mlir::LogicalResult
379+
matchAndRewrite(hlfir::AssignOp assign,
380+
mlir::PatternRewriter &rewriter) const override;
381+
};
382+
383+
/// Replaces `hlfir.assign <scalar> to <array>` with an unordered loop nest
/// that assigns the scalar to every element of the array.
///
/// \param assign   the candidate hlfir.assign operation.
/// \param rewriter the pattern rewriter driving the transformation.
/// \returns success when the loop nest was generated and \p assign erased;
///          a match failure (with a diagnostic message) otherwise.
mlir::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
    hlfir::AssignOp assign, mlir::PatternRewriter &rewriter) const {
  // Assignments flagged as allocatable assignments are left alone.
  if (assign.isAllocatableAssignment())
    return rewriter.notifyMatchFailure(assign, "AssignOp may imply allocation");

  // The broadcast value must have a type accepted by fir::isa_trivial.
  mlir::Value scalar = assign.getRhs();
  if (!fir::isa_trivial(scalar.getType()))
    return rewriter.notifyMatchFailure(
        assign, "AssignOp's RHS is not a trivial scalar");

  // The destination must be an array ...
  hlfir::Entity target{assign.getLhs()};
  if (!target.isArray())
    return rewriter.notifyMatchFailure(assign,
                                       "AssignOp's LHS is not an array");

  // ... whose element type is also accepted by fir::isa_trivial.
  if (!fir::isa_trivial(target.getFortranElementType()))
    return rewriter.notifyMatchFailure(
        assign, "AssignOp's LHS data type is not trivial");

  // All new operations are emitted right before the original assignment.
  mlir::Location loc = assign->getLoc();
  fir::FirOpBuilder builder(rewriter, assign.getOperation());
  builder.setInsertionPoint(assign);

  // Dereference pointer/allocatable entities, then derive the loop bounds
  // from the shape of the resulting entity.
  hlfir::Entity array =
      hlfir::derefPointersAndAllocatables(loc, builder, target);
  llvm::SmallVector<mlir::Value> extents = hlfir::getIndexExtents(
      loc, builder, hlfir::genShape(loc, builder, array));
  hlfir::LoopNest nest =
      hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);

  // Inside the innermost loop: address the current element and store the
  // scalar into it with a scalar-to-scalar hlfir.assign.
  builder.setInsertionPointToStart(nest.innerLoop.getBody());
  auto element = hlfir::getElementAt(loc, builder, array, nest.oneBasedIndices);
  builder.create<hlfir::AssignOp>(loc, scalar, element);

  // The whole-array assignment is now redundant.
  rewriter.eraseOp(assign);
  return mlir::success();
}
419+
361420
class OptimizedBufferizationPass
362421
: public hlfir::impl::OptimizedBufferizationBase<
363422
OptimizedBufferizationPass> {
@@ -371,7 +430,14 @@ class OptimizedBufferizationPass
371430
config.enableRegionSimplification = false;
372431

373432
mlir::RewritePatternSet patterns(context);
433+
// TODO: right now the patterns are non-conflicting,
434+
// but it might be better to run this pass on hlfir.assign
435+
// operations and decide which transformation to apply
436+
// in one place (e.g. we may use some heuristics and
437+
// choose different optimization strategies).
438+
// This requires small code reordering in ElementalAssignBufferization.
374439
patterns.insert<ElementalAssignBufferization>(context);
440+
patterns.insert<BroadcastAssignBufferization>(context);
375441

376442
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
377443
func, std::move(patterns), config))) {
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Test optimized bufferization for hlfir.assign with scalar RHS.
2+
// RUN: fir-opt --opt-bufferization %s | FileCheck %s
3+
4+
// Static-shape case: a scalar f32 constant is assigned to an 11x13 f32 array
// declared from a local fir.alloca. The FileCheck expectations that follow
// require a two-level unordered fir.do_loop nest with an element-wise
// hlfir.assign in the innermost loop.
func.func @_QPtest1() {
5+
%cst = arith.constant 0.000000e+00 : f32
6+
%c11 = arith.constant 11 : index
7+
%c13 = arith.constant 13 : index
8+
%0 = fir.alloca !fir.array<11x13xf32> {bindc_name = "x", uniq_name = "_QFtest1Ex"}
9+
%1 = fir.shape %c11, %c13 : (index, index) -> !fir.shape<2>
10+
%2:2 = hlfir.declare %0(%1) {uniq_name = "_QFtest1Ex"} : (!fir.ref<!fir.array<11x13xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<11x13xf32>>, !fir.ref<!fir.array<11x13xf32>>)
11+
hlfir.assign %cst to %2#0 : f32, !fir.ref<!fir.array<11x13xf32>>
12+
return
13+
}
14+
// CHECK-LABEL: func.func @_QPtest1() {
15+
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
16+
// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
17+
// CHECK: %[[VAL_2:.*]] = arith.constant 11 : index
18+
// CHECK: %[[VAL_3:.*]] = arith.constant 13 : index
19+
// CHECK: %[[VAL_4:.*]] = fir.alloca !fir.array<11x13xf32> {bindc_name = "x", uniq_name = "_QFtest1Ex"}
20+
// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_2]], %[[VAL_3]] : (index, index) -> !fir.shape<2>
21+
// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_4]](%[[VAL_5]]) {uniq_name = "_QFtest1Ex"} : (!fir.ref<!fir.array<11x13xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<11x13xf32>>, !fir.ref<!fir.array<11x13xf32>>)
22+
// CHECK: fir.do_loop %[[VAL_7:.*]] = %[[VAL_0]] to %[[VAL_3]] step %[[VAL_0]] unordered {
23+
// CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_0]] unordered {
24+
// CHECK: %[[VAL_9:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_8]], %[[VAL_7]]) : (!fir.ref<!fir.array<11x13xf32>>, index, index) -> !fir.ref<f32>
25+
// CHECK: hlfir.assign %[[VAL_1]] to %[[VAL_9]] : f32, !fir.ref<f32>
26+
// CHECK: }
27+
// CHECK: }
28+
// CHECK: return
29+
// CHECK: }
30+
31+
// Dynamic-shape case: an i32 zero is assigned to a boxed ?x? i32 array passed
// as a function argument. The FileCheck expectations that follow require the
// loop upper bounds to come from fir.box_dims results.
func.func @_QPtest2(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "x"}) {
32+
%c0_i32 = arith.constant 0 : i32
33+
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFtest2Ex"} : (!fir.box<!fir.array<?x?xi32>>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>)
34+
hlfir.assign %c0_i32 to %0#0 : i32, !fir.box<!fir.array<?x?xi32>>
35+
return
36+
}
37+
// CHECK-LABEL: func.func @_QPtest2(
38+
// CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "x"}) {
39+
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
40+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
41+
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32
42+
// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFtest2Ex"} : (!fir.box<!fir.array<?x?xi32>>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>)
43+
// CHECK: %[[VAL_5:.*]]:3 = fir.box_dims %[[VAL_4]]#0, %[[VAL_2]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
44+
// CHECK: %[[VAL_6:.*]]:3 = fir.box_dims %[[VAL_4]]#0, %[[VAL_1]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
45+
// CHECK: fir.do_loop %[[VAL_7:.*]] = %[[VAL_1]] to %[[VAL_6]]#1 step %[[VAL_1]] unordered {
46+
// CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_1]] to %[[VAL_5]]#1 step %[[VAL_1]] unordered {
47+
// CHECK: %[[VAL_9:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_8]], %[[VAL_7]]) : (!fir.box<!fir.array<?x?xi32>>, index, index) -> !fir.ref<i32>
48+
// CHECK: hlfir.assign %[[VAL_3]] to %[[VAL_9]] : i32, !fir.ref<i32>
49+
// CHECK: }
50+
// CHECK: }
51+
// CHECK: return
52+
// CHECK: }
53+
54+
// Pointer case: a !fir.logical<4> value (converted from the i1 constant true)
// is assigned to the box loaded from a pointer array variable. The FileCheck
// expectations that follow require the hlfir.designate index to be the loop
// index plus (fir.box_dims result #0 minus one) -- presumably a lower-bound
// adjustment; confirm against the fir.box_dims definition.
func.func @_QPtest4(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>> {fir.bindc_name = "x"}) {
55+
%true = arith.constant true
56+
%0:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest4Ex"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>>)
57+
%1 = fir.convert %true : (i1) -> !fir.logical<4>
58+
%2 = fir.load %0#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>>
59+
hlfir.assign %1 to %2 : !fir.logical<4>, !fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>
60+
return
61+
}
62+
// CHECK-LABEL: func.func @_QPtest4(
63+
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>> {fir.bindc_name = "x"}) {
64+
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
65+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
66+
// CHECK: %[[VAL_3:.*]] = arith.constant true
67+
// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest4Ex"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>>)
68+
// CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_3]] : (i1) -> !fir.logical<4>
69+
// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>>
70+
// CHECK: %[[VAL_7:.*]]:3 = fir.box_dims %[[VAL_6]], %[[VAL_2]] : (!fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>, index) -> (index, index, index)
71+
// CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_1]] to %[[VAL_7]]#1 step %[[VAL_1]] unordered {
72+
// CHECK: %[[VAL_9:.*]]:3 = fir.box_dims %[[VAL_6]], %[[VAL_2]] : (!fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>, index) -> (index, index, index)
73+
// CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_9]]#0, %[[VAL_1]] : index
74+
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_8]], %[[VAL_10]] : index
75+
// CHECK: %[[VAL_12:.*]] = hlfir.designate %[[VAL_6]] (%[[VAL_11]]) : (!fir.box<!fir.ptr<!fir.array<?x!fir.logical<4>>>>, index) -> !fir.ref<!fir.logical<4>>
76+
// CHECK: hlfir.assign %[[VAL_5]] to %[[VAL_12]] : !fir.logical<4>, !fir.ref<!fir.logical<4>>
77+
// CHECK: }
78+
// CHECK: return
79+
// CHECK: }
80+
81+
// Negative case: an i32 zero is assigned with the `realloc` attribute to an
// allocatable array variable. The FileCheck expectations that follow require
// the hlfir.assign to remain in place, i.e. no loop nest is generated.
func.func @_QPtest3(%arg0: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {fir.bindc_name = "x"}) {
82+
%c0_i32 = arith.constant 0 : i32
83+
%0:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtest3Ex"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
84+
hlfir.assign %c0_i32 to %0#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
85+
return
86+
}
87+
// CHECK-LABEL: func.func @_QPtest3(
88+
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {fir.bindc_name = "x"}) {
89+
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i32
90+
// CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtest3Ex"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
91+
// CHECK: hlfir.assign %[[VAL_1]] to %[[VAL_2]]#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
92+
// CHECK: return
93+
// CHECK: }
94+
95+
// Complex case: a !fir.complex<4> value built with two fir.insert_value
// operations is assigned to a 77-element !fir.complex<4> array. The FileCheck
// expectations that follow require a single unordered fir.do_loop with an
// element-wise hlfir.assign of the complex value.
func.func @_QPtest5(%arg0: !fir.ref<!fir.array<77x!fir.complex<4>>> {fir.bindc_name = "x"}) {
96+
%cst = arith.constant 0.000000e+00 : f32
97+
%c77 = arith.constant 77 : index
98+
%0 = fir.shape %c77 : (index) -> !fir.shape<1>
99+
%1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFtest5Ex"} : (!fir.ref<!fir.array<77x!fir.complex<4>>>, !fir.shape<1>) -> (!fir.ref<!fir.array<77x!fir.complex<4>>>, !fir.ref<!fir.array<77x!fir.complex<4>>>)
100+
%2 = fir.undefined !fir.complex<4>
101+
%3 = fir.insert_value %2, %cst, [0 : index] : (!fir.complex<4>, f32) -> !fir.complex<4>
102+
%4 = fir.insert_value %3, %cst, [1 : index] : (!fir.complex<4>, f32) -> !fir.complex<4>
103+
hlfir.assign %4 to %1#0 : !fir.complex<4>, !fir.ref<!fir.array<77x!fir.complex<4>>>
104+
return
105+
}
106+
// CHECK-LABEL: func.func @_QPtest5(
107+
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<77x!fir.complex<4>>> {fir.bindc_name = "x"}) {
108+
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
109+
// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
110+
// CHECK: %[[VAL_3:.*]] = arith.constant 77 : index
111+
// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
112+
// CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_4]]) {uniq_name = "_QFtest5Ex"} : (!fir.ref<!fir.array<77x!fir.complex<4>>>, !fir.shape<1>) -> (!fir.ref<!fir.array<77x!fir.complex<4>>>, !fir.ref<!fir.array<77x!fir.complex<4>>>)
113+
// CHECK: %[[VAL_6:.*]] = fir.undefined !fir.complex<4>
114+
// CHECK: %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_2]], [0 : index] : (!fir.complex<4>, f32) -> !fir.complex<4>
115+
// CHECK: %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_2]], [1 : index] : (!fir.complex<4>, f32) -> !fir.complex<4>
116+
// CHECK: fir.do_loop %[[VAL_9:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_1]] unordered {
117+
// CHECK: %[[VAL_10:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_9]]) : (!fir.ref<!fir.array<77x!fir.complex<4>>>, index) -> !fir.ref<!fir.complex<4>>
118+
// CHECK: hlfir.assign %[[VAL_8]] to %[[VAL_10]] : !fir.complex<4>, !fir.ref<!fir.complex<4>>
119+
// CHECK: }
120+
// CHECK: return
121+
// CHECK: }

0 commit comments

Comments
 (0)