@@ -81,6 +81,15 @@ static bool hasDoubleDescriptors(OpTy op) {
81
81
return false ;
82
82
}
83
83
84
+ bool isDeviceGlobal (fir::GlobalOp op) {
85
+ auto attr = op.getDataAttr ();
86
+ if (attr && (*attr == cuf::DataAttribute::Device ||
87
+ *attr == cuf::DataAttribute::Managed ||
88
+ *attr == cuf::DataAttribute::Constant))
89
+ return true ;
90
+ return false ;
91
+ }
92
+
84
93
static mlir::Value createConvertOp (mlir::PatternRewriter &rewriter,
85
94
mlir::Location loc, mlir::Type toTy,
86
95
mlir::Value val) {
@@ -89,62 +98,6 @@ static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
89
98
return val;
90
99
}
91
100
92
- mlir::Value getDeviceAddress (mlir::PatternRewriter &rewriter,
93
- mlir::OpOperand &operand,
94
- const mlir::SymbolTable &symtab) {
95
- mlir::Value v = operand.get ();
96
- auto declareOp = v.getDefiningOp <fir::DeclareOp>();
97
- if (!declareOp)
98
- return v;
99
-
100
- auto addrOfOp = declareOp.getMemref ().getDefiningOp <fir::AddrOfOp>();
101
- if (!addrOfOp)
102
- return v;
103
-
104
- auto globalOp = symtab.lookup <fir::GlobalOp>(
105
- addrOfOp.getSymbol ().getRootReference ().getValue ());
106
-
107
- if (!globalOp)
108
- return v;
109
-
110
- bool isDevGlobal{false };
111
- auto attr = globalOp.getDataAttrAttr ();
112
- if (attr) {
113
- switch (attr.getValue ()) {
114
- case cuf::DataAttribute::Device:
115
- case cuf::DataAttribute::Managed:
116
- case cuf::DataAttribute::Constant:
117
- isDevGlobal = true ;
118
- break ;
119
- default :
120
- break ;
121
- }
122
- }
123
- if (!isDevGlobal)
124
- return v;
125
- mlir::OpBuilder::InsertionGuard guard (rewriter);
126
- rewriter.setInsertionPoint (operand.getOwner ());
127
- auto loc = declareOp.getLoc ();
128
- auto mod = declareOp->getParentOfType <mlir::ModuleOp>();
129
- fir::FirOpBuilder builder (rewriter, mod);
130
-
131
- mlir::func::FuncOp callee =
132
- fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(loc, builder);
133
- auto fTy = callee.getFunctionType ();
134
- auto toTy = fTy .getInput (0 );
135
- mlir::Value inputArg =
136
- createConvertOp (rewriter, loc, toTy, declareOp.getResult ());
137
- mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
138
- mlir::Value sourceLine =
139
- fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
140
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
141
- builder, loc, fTy , inputArg, sourceFile, sourceLine)};
142
- auto call = rewriter.create <fir::CallOp>(loc, callee, args);
143
- mlir::Value cast = createConvertOp (
144
- rewriter, loc, declareOp.getMemref ().getType (), call->getResult (0 ));
145
- return cast;
146
- }
147
-
148
101
template <typename OpTy>
149
102
static mlir::LogicalResult convertOpToCall (OpTy op,
150
103
mlir::PatternRewriter &rewriter,
@@ -422,6 +375,54 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
422
375
const fir::LLVMTypeConverter *typeConverter;
423
376
};
424
377
378
+ struct DeclareOpConversion : public mlir ::OpRewritePattern<fir::DeclareOp> {
379
+ using OpRewritePattern::OpRewritePattern;
380
+
381
+ DeclareOpConversion (mlir::MLIRContext *context,
382
+ const mlir::SymbolTable &symtab)
383
+ : OpRewritePattern(context), symTab{symtab} {}
384
+
385
+ mlir::LogicalResult
386
+ matchAndRewrite (fir::DeclareOp op,
387
+ mlir::PatternRewriter &rewriter) const override {
388
+ if (auto addrOfOp = op.getMemref ().getDefiningOp <fir::AddrOfOp>()) {
389
+ if (auto global = symTab.lookup <fir::GlobalOp>(
390
+ addrOfOp.getSymbol ().getRootReference ().getValue ())) {
391
+ if (isDeviceGlobal (global)) {
392
+ rewriter.setInsertionPointAfter (addrOfOp);
393
+ auto mod = op->getParentOfType <mlir::ModuleOp>();
394
+ fir::FirOpBuilder builder (rewriter, mod);
395
+ mlir::Location loc = op.getLoc ();
396
+ mlir::func::FuncOp callee =
397
+ fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(
398
+ loc, builder);
399
+ auto fTy = callee.getFunctionType ();
400
+ mlir::Type toTy = fTy .getInput (0 );
401
+ mlir::Value inputArg =
402
+ createConvertOp (rewriter, loc, toTy, addrOfOp.getResult ());
403
+ mlir::Value sourceFile =
404
+ fir::factory::locationToFilename (builder, loc);
405
+ mlir::Value sourceLine =
406
+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
407
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
408
+ builder, loc, fTy , inputArg, sourceFile, sourceLine)};
409
+ auto call = rewriter.create <fir::CallOp>(loc, callee, args);
410
+ mlir::Value cast = createConvertOp (
411
+ rewriter, loc, op.getMemref ().getType (), call->getResult (0 ));
412
+ rewriter.startOpModification (op);
413
+ op.getMemrefMutable ().assign (cast);
414
+ rewriter.finalizeOpModification (op);
415
+ return success ();
416
+ }
417
+ }
418
+ }
419
+ return failure ();
420
+ }
421
+
422
+ private:
423
+ const mlir::SymbolTable &symTab;
424
+ };
425
+
425
426
struct CUFFreeOpConversion : public mlir ::OpRewritePattern<cuf::FreeOp> {
426
427
using OpRewritePattern::OpRewritePattern;
427
428
@@ -511,7 +512,7 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
511
512
builder.create <fir::StoreOp>(loc, src, alloc);
512
513
addr = alloc;
513
514
} else {
514
- addr = getDeviceAddress (rewriter, op.getSrcMutable (), symtab );
515
+ addr = op.getSrc ( );
515
516
}
516
517
llvm::SmallVector<mlir::Value> lenParams;
517
518
mlir::Type boxTy = fir::BoxType::get (srcTy);
@@ -531,7 +532,7 @@ static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
531
532
mlir::Location loc = op.getLoc ();
532
533
fir::FirOpBuilder builder (rewriter, mod);
533
534
mlir::Type dstTy = fir::unwrapRefType (op.getDst ().getType ());
534
- mlir::Value dstAddr = getDeviceAddress (rewriter, op.getDstMutable (), symtab );
535
+ mlir::Value dstAddr = op.getDst ( );
535
536
mlir::Type dstBoxTy = fir::BoxType::get (dstTy);
536
537
llvm::SmallVector<mlir::Value> lenParams;
537
538
mlir::Value dstBox =
@@ -652,8 +653,8 @@ struct CUFDataTransferOpConversion
652
653
mlir::Value sourceLine =
653
654
fir::factory::locationToLineNo (builder, loc, fTy .getInput (5 ));
654
655
655
- mlir::Value dst = getDeviceAddress (rewriter, op.getDstMutable (), symtab );
656
- mlir::Value src = getDeviceAddress (rewriter, op.getSrcMutable (), symtab );
656
+ mlir::Value dst = op.getDst ( );
657
+ mlir::Value src = op.getSrc ( );
657
658
// Materialize the src if constant.
658
659
if (matchPattern (src.getDefiningOp (), mlir::m_Constant ())) {
659
660
mlir::Value temp = builder.createTemporary (loc, srcTy);
@@ -823,6 +824,30 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
823
824
" error in CUF op conversion\n " );
824
825
signalPassFailure ();
825
826
}
827
+
828
+ target.addDynamicallyLegalOp <fir::DeclareOp>([&](fir::DeclareOp op) {
829
+ if (inDeviceContext (op))
830
+ return true ;
831
+ if (auto addrOfOp = op.getMemref ().getDefiningOp <fir::AddrOfOp>()) {
832
+ if (auto global = symtab.lookup <fir::GlobalOp>(
833
+ addrOfOp.getSymbol ().getRootReference ().getValue ())) {
834
+ if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType (global.getType ())))
835
+ return true ;
836
+ if (isDeviceGlobal (global))
837
+ return false ;
838
+ }
839
+ }
840
+ return true ;
841
+ });
842
+
843
+ patterns.clear ();
844
+ cuf::populateFIRCUFConversionPatterns (symtab, patterns);
845
+ if (mlir::failed (mlir::applyPartialConversion (getOperation (), target,
846
+ std::move (patterns)))) {
847
+ mlir::emitError (mlir::UnknownLoc::get (ctx),
848
+ " error in CUF op conversion\n " );
849
+ signalPassFailure ();
850
+ }
826
851
}
827
852
};
828
853
} // namespace
@@ -837,3 +862,8 @@ void cuf::populateCUFToFIRConversionPatterns(
837
862
&dl, &converter);
838
863
patterns.insert <CUFLaunchOpConversion>(patterns.getContext (), symtab);
839
864
}
865
+
866
+ void cuf::populateFIRCUFConversionPatterns (const mlir::SymbolTable &symtab,
867
+ mlir::RewritePatternSet &patterns) {
868
+ patterns.insert <DeclareOpConversion>(patterns.getContext (), symtab);
869
+ }
0 commit comments