@@ -417,6 +417,102 @@ mlir::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
417
417
return mlir::success ();
418
418
}
419
419
420
+ // / Expand hlfir.assign of array RHS to array LHS into a loop nest
421
+ // / of element-by-element assignments:
422
+ // / hlfir.assign %4 to %5 : !fir.ref<!fir.array<3x3xf32>>,
423
+ // / !fir.ref<!fir.array<3x3xf32>>
424
+ // / into:
425
+ // / fir.do_loop %arg1 = %c1 to %c3 step %c1 unordered {
426
+ // / fir.do_loop %arg2 = %c1 to %c3 step %c1 unordered {
427
+ // / %6 = hlfir.designate %4 (%arg2, %arg1) :
428
+ // / (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32>
429
+ // / %7 = fir.load %6 : !fir.ref<f32>
430
+ // / %8 = hlfir.designate %5 (%arg2, %arg1) :
431
+ // / (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32>
432
+ // / hlfir.assign %7 to %8 : f32, !fir.ref<f32>
433
+ // / }
434
+ // / }
435
+ // /
436
+ // / The transformation is correct only when LHS and RHS do not alias.
437
+ // / This transformation does not support runtime checking for
438
+ // / non-conforming LHS/RHS arrays' shapes currently.
439
+ class VariableAssignBufferization
440
+ : public mlir::OpRewritePattern<hlfir::AssignOp> {
441
+ private:
442
+ public:
443
+ using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern;
444
+
445
+ mlir::LogicalResult
446
+ matchAndRewrite (hlfir::AssignOp assign,
447
+ mlir::PatternRewriter &rewriter) const override ;
448
+ };
449
+
450
+ mlir::LogicalResult VariableAssignBufferization::matchAndRewrite (
451
+ hlfir::AssignOp assign, mlir::PatternRewriter &rewriter) const {
452
+ if (assign.isAllocatableAssignment ())
453
+ return rewriter.notifyMatchFailure (assign, " AssignOp may imply allocation" );
454
+
455
+ hlfir::Entity rhs{assign.getRhs ()};
456
+ // TODO: ExprType check is here to avoid conflicts with
457
+ // ElementalAssignBufferization pattern. We need to combine
458
+ // these matchers into a single one that applies to AssignOp.
459
+ if (rhs.getType ().isa <hlfir::ExprType>())
460
+ return rewriter.notifyMatchFailure (assign, " RHS is not in memory" );
461
+
462
+ if (!rhs.isArray ())
463
+ return rewriter.notifyMatchFailure (assign,
464
+ " AssignOp's RHS is not an array" );
465
+
466
+ mlir::Type rhsEleTy = rhs.getFortranElementType ();
467
+ if (!fir::isa_trivial (rhsEleTy))
468
+ return rewriter.notifyMatchFailure (
469
+ assign, " AssignOp's RHS data type is not trivial" );
470
+
471
+ hlfir::Entity lhs{assign.getLhs ()};
472
+ if (!lhs.isArray ())
473
+ return rewriter.notifyMatchFailure (assign,
474
+ " AssignOp's LHS is not an array" );
475
+
476
+ mlir::Type lhsEleTy = lhs.getFortranElementType ();
477
+ if (!fir::isa_trivial (lhsEleTy))
478
+ return rewriter.notifyMatchFailure (
479
+ assign, " AssignOp's LHS data type is not trivial" );
480
+
481
+ if (lhsEleTy != rhsEleTy)
482
+ return rewriter.notifyMatchFailure (assign,
483
+ " RHS/LHS element types mismatch" );
484
+
485
+ fir::AliasAnalysis aliasAnalysis;
486
+ mlir::AliasResult aliasRes = aliasAnalysis.alias (lhs, rhs);
487
+ if (!aliasRes.isNo ()) {
488
+ LLVM_DEBUG (llvm::dbgs () << " VariableAssignBufferization:\n "
489
+ << " \t LHS: " << lhs << " \n "
490
+ << " \t RHS: " << rhs << " \n "
491
+ << " \t ALIAS: " << aliasRes << " \n " );
492
+ return rewriter.notifyMatchFailure (assign, " RHS/LHS may alias" );
493
+ }
494
+
495
+ mlir::Location loc = assign->getLoc ();
496
+ fir::FirOpBuilder builder (rewriter, assign.getOperation ());
497
+ builder.setInsertionPoint (assign);
498
+ rhs = hlfir::derefPointersAndAllocatables (loc, builder, rhs);
499
+ lhs = hlfir::derefPointersAndAllocatables (loc, builder, lhs);
500
+ mlir::Value shape = hlfir::genShape (loc, builder, lhs);
501
+ llvm::SmallVector<mlir::Value> extents =
502
+ hlfir::getIndexExtents (loc, builder, shape);
503
+ hlfir::LoopNest loopNest =
504
+ hlfir::genLoopNest (loc, builder, extents, /* isUnordered=*/ true );
505
+ builder.setInsertionPointToStart (loopNest.innerLoop .getBody ());
506
+ auto rhsArrayElement =
507
+ hlfir::getElementAt (loc, builder, rhs, loopNest.oneBasedIndices );
508
+ rhsArrayElement = hlfir::loadTrivialScalar (loc, builder, rhsArrayElement);
509
+ auto lhsArrayElement =
510
+ hlfir::getElementAt (loc, builder, lhs, loopNest.oneBasedIndices );
511
+ builder.create <hlfir::AssignOp>(loc, rhsArrayElement, lhsArrayElement);
512
+ rewriter.eraseOp (assign);
513
+ return mlir::success ();
514
+ }
515
+
420
516
class OptimizedBufferizationPass
421
517
: public hlfir::impl::OptimizedBufferizationBase<
422
518
OptimizedBufferizationPass> {
@@ -438,6 +534,7 @@ class OptimizedBufferizationPass
438
534
// This requires small code reordering in ElementalAssignBufferization.
439
535
patterns.insert <ElementalAssignBufferization>(context);
440
536
patterns.insert <BroadcastAssignBufferization>(context);
537
+ patterns.insert <VariableAssignBufferization>(context);
441
538
442
539
if (mlir::failed (mlir::applyPatternsAndFoldGreedily (
443
540
func, std::move (patterns), config))) {
0 commit comments