@@ -6478,31 +6478,23 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
6478
6478
assert(args.size() == 3);
6479
6479
bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
6480
6480
6481
- llvm::StringRef funcName =
6482
- is32 ? "llvm.nvvm.match.all.sync.i32p" : "llvm.nvvm.match.all.sync.i64p";
6483
- mlir::MLIRContext *context = builder.getContext();
6484
- mlir::Type i32Ty = builder.getI32Type();
6485
- mlir::Type i64Ty = builder.getI64Type();
6486
6481
mlir::Type i1Ty = builder.getI1Type();
6487
- mlir::Type retTy = mlir::TupleType::get(context, {resultType, i1Ty});
6488
- mlir::Type valTy = is32 ? i32Ty : i64Ty;
6482
+ mlir::MLIRContext *context = builder.getContext();
6489
6483
6490
- mlir::FunctionType ftype =
6491
- mlir::FunctionType::get(context, {i32Ty, valTy}, {retTy});
6492
- auto funcOp = builder.createFunction(loc, funcName, ftype);
6493
- llvm::SmallVector<mlir::Value> filteredArgs;
6494
- filteredArgs.push_back(args[0]);
6495
- if (args[1].getType().isF32() || args[1].getType().isF64())
6496
- filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
6497
- else
6498
- filteredArgs.push_back(args[1]);
6499
- auto call = builder.create<fir::CallOp>(loc, funcOp, filteredArgs);
6500
- auto zero = builder.getIntegerAttr(builder.getIndexType(), 0);
6501
- auto value = builder.create<fir::ExtractValueOp>(
6502
- loc, resultType, call.getResult(0), builder.getArrayAttr(zero));
6503
- auto one = builder.getIntegerAttr(builder.getIndexType(), 1);
6504
- auto pred = builder.create<fir::ExtractValueOp>(loc, i1Ty, call.getResult(0),
6505
- builder.getArrayAttr(one));
6484
+ mlir::Value arg1 = args[1];
6485
+ if (arg1.getType().isF32() || arg1.getType().isF64())
6486
+ arg1 = builder.create<fir::ConvertOp>(
6487
+ loc, is32 ? builder.getI32Type() : builder.getI64Type(), arg1);
6488
+
6489
+ mlir::Type retTy =
6490
+ mlir::LLVM::LLVMStructType::getLiteral(context, {resultType, i1Ty});
6491
+ auto match =
6492
+ builder
6493
+ .create<mlir::NVVM::MatchSyncOp>(loc, retTy, args[0], arg1,
6494
+ mlir::NVVM::MatchSyncKind::all)
6495
+ .getResult();
6496
+ auto value = builder.create<mlir::LLVM::ExtractValueOp>(loc, match, 0);
6497
+ auto pred = builder.create<mlir::LLVM::ExtractValueOp>(loc, match, 1);
6506
6498
auto conv = builder.create<mlir::LLVM::ZExtOp>(loc, resultType, pred);
6507
6499
builder.create<fir::StoreOp>(loc, conv, args[2]);
6508
6500
return value;
0 commit comments