Skip to content

Commit de40f61

Browse files
authored
[flang][cuda][NFC] Use NVVM op for match all (#134303)
1 parent 0f696c2 commit de40f61

File tree

2 files changed

+19
-29
lines changed

2 files changed

+19
-29
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6478,31 +6478,23 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
64786478
assert(args.size() == 3);
64796479
bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
64806480

6481-
llvm::StringRef funcName =
6482-
is32 ? "llvm.nvvm.match.all.sync.i32p" : "llvm.nvvm.match.all.sync.i64p";
6483-
mlir::MLIRContext *context = builder.getContext();
6484-
mlir::Type i32Ty = builder.getI32Type();
6485-
mlir::Type i64Ty = builder.getI64Type();
64866481
mlir::Type i1Ty = builder.getI1Type();
6487-
mlir::Type retTy = mlir::TupleType::get(context, {resultType, i1Ty});
6488-
mlir::Type valTy = is32 ? i32Ty : i64Ty;
6482+
mlir::MLIRContext *context = builder.getContext();
64896483

6490-
mlir::FunctionType ftype =
6491-
mlir::FunctionType::get(context, {i32Ty, valTy}, {retTy});
6492-
auto funcOp = builder.createFunction(loc, funcName, ftype);
6493-
llvm::SmallVector<mlir::Value> filteredArgs;
6494-
filteredArgs.push_back(args[0]);
6495-
if (args[1].getType().isF32() || args[1].getType().isF64())
6496-
filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
6497-
else
6498-
filteredArgs.push_back(args[1]);
6499-
auto call = builder.create<fir::CallOp>(loc, funcOp, filteredArgs);
6500-
auto zero = builder.getIntegerAttr(builder.getIndexType(), 0);
6501-
auto value = builder.create<fir::ExtractValueOp>(
6502-
loc, resultType, call.getResult(0), builder.getArrayAttr(zero));
6503-
auto one = builder.getIntegerAttr(builder.getIndexType(), 1);
6504-
auto pred = builder.create<fir::ExtractValueOp>(loc, i1Ty, call.getResult(0),
6505-
builder.getArrayAttr(one));
6484+
mlir::Value arg1 = args[1];
6485+
if (arg1.getType().isF32() || arg1.getType().isF64())
6486+
arg1 = builder.create<fir::ConvertOp>(
6487+
loc, is32 ? builder.getI32Type() : builder.getI64Type(), arg1);
6488+
6489+
mlir::Type retTy =
6490+
mlir::LLVM::LLVMStructType::getLiteral(context, {resultType, i1Ty});
6491+
auto match =
6492+
builder
6493+
.create<mlir::NVVM::MatchSyncOp>(loc, retTy, args[0], arg1,
6494+
mlir::NVVM::MatchSyncKind::all)
6495+
.getResult();
6496+
auto value = builder.create<mlir::LLVM::ExtractValueOp>(loc, match, 0);
6497+
auto pred = builder.create<mlir::LLVM::ExtractValueOp>(loc, match, 1);
65066498
auto conv = builder.create<mlir::LLVM::ZExtOp>(loc, resultType, pred);
65076499
builder.create<fir::StoreOp>(loc, conv, args[2]);
65086500
return value;

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,10 @@ attributes(device) subroutine testMatch()
124124
end subroutine
125125

126126
! CHECK-LABEL: func.func @_QPtestmatch()
127-
! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p
128-
! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
129-
! CHECK: fir.convert %{{.*}} : (f32) -> i32
130-
! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p
131-
! CHECK: fir.convert %{{.*}} : (f64) -> i64
132-
! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
127+
! CHECK: %{{.*}} = nvvm.match.sync all %{{.*}}, %{{.*}} : i32 -> !llvm.struct<(i32, i1)>
128+
! CHECK: %{{.*}} = nvvm.match.sync all %{{.*}}, %{{.*}} : i64 -> !llvm.struct<(i32, i1)>
129+
! CHECK: %{{.*}} = nvvm.match.sync all %{{.*}}, %{{.*}} : i32 -> !llvm.struct<(i32, i1)>
130+
! CHECK: %{{.*}} = nvvm.match.sync all %{{.*}}, %{{.*}} : i64 -> !llvm.struct<(i32, i1)>
133131

134132
attributes(device) subroutine testMatchAny()
135133
integer :: a, mask, v32

0 commit comments

Comments
 (0)