Skip to content

Commit 51b65d0

Browse files
committed
[mlir][AMDGPU] Improve BF16 handling through AMDGPU compilation
Many previous sets of AMDGPU dialect code have been incorrect in the presence of the bf16 type (when lowered to LLVM's bfloat) as they were developed in a setting that run a custom bf16-to-i16 pass before LLVM lowering. An overall effect of this patch is that you should run --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" on your GPU module before calling --convert-gpu-to-rocdl if your code performs bf16 arithmetic. While LLVM now supports software bfloat, initial experiments showed that using this support on AMDGPU inserted a large number of conversions around loads and stores which had substantial performance imparts. Furthermore, all of the native AMDGPU operations on bf16 types (like the WMMA operations) operate on 16-bit integers instead of the bfloat type. First, we make the following changes to preserve compatibility once the LLVM bfloat type is reenabled. 1. The matrix multiplication operations (MFMA and WMMA) will bitcast bfloat vectors to i16 vectors. 2. Buffer loads and stores will operate on the relevant integer datatype and then cast to bfloat if needed. Second, we add type conversions to convert bf16 and vectors of it to equivalent i16 types. Third, we add the bfloat <-> f32 expansion patterns to the set of operations run before the main LLVM conversion so that MLIR's implementation of these conversion routines is used. Finally, we extend the "floats treated as integers" support in the LLVM exporter to handle types other than fp8. We also fix a bug in the unsupported floats emulation where it tried to operate on `arith.bitcast` due to an oversight. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D156361
1 parent 6aab000 commit 51b65d0

File tree

10 files changed

+102
-25
lines changed

10 files changed

+102
-25
lines changed

mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ class Pass;
2121
#define GEN_PASS_DECL_CONVERTAMDGPUTOROCDL
2222
#include "mlir/Conversion/Passes.h.inc"
2323

24+
/// Note: The ROCDL target does not support the LLVM bfloat type at this time
25+
/// and so this function will add conversions to change all `bfloat` uses
26+
/// to `i16`.
2427
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
2528
RewritePatternSet &patterns,
2629
amdgpu::Chipset chipset);

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
1414
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1515
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
16+
#include "mlir/IR/BuiltinTypes.h"
17+
#include "mlir/IR/TypeUtilities.h"
1618
#include "mlir/Pass/Pass.h"
1719

1820
#include "llvm/ADT/STLExtras.h"
@@ -88,8 +90,15 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
8890
// bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
8991
// and the total load size is >= 32, use a vector load of N / (bitsize(T) /
9092
// 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
91-
// so bitcast any floats to integers.
93+
// so bitcast any floats to integers. On top of all this, cast bfloat
94+
// (vectors) to i16 since the backend doesn't currently support bfloat on
95+
// these operations.
9296
Type llvmBufferValType = llvmWantedDataType;
97+
if (wantedDataType.isBF16())
98+
llvmBufferValType = rewriter.getI16Type();
99+
if (auto wantedVecType = dyn_cast<VectorType>(wantedDataType))
100+
if (wantedVecType.getElementType().isBF16())
101+
llvmBufferValType = wantedVecType.clone(rewriter.getI16Type());
93102
if (atomicCmpData) {
94103
if (isa<VectorType>(wantedDataType))
95104
return gpuOp.emitOpError("vector compare-and-swap does not exist");
@@ -315,10 +324,17 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
315324
/// around a wart in the AMDGPU intrinsics where operations that logically take
316325
/// vectors of bytes instead integers. Since we do not want to expose this
317326
/// implementation detail to MLIR, we correct for it here.
327+
///
328+
/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
329+
/// MFMA intrinsics pre-date the bfloat type.
318330
static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
319331
Location loc, Value input) {
320332
Type inputType = input.getType();
321333
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
334+
if (vectorType.getElementType().isBF16())
335+
return rewriter.create<LLVM::BitcastOp>(
336+
loc, vectorType.clone(rewriter.getI16Type()), input);
337+
322338
if (!vectorType.getElementType().isInteger(8))
323339
return input;
324340
int64_t numBytes = vectorType.getNumElements();
@@ -343,7 +359,8 @@ static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
343359
/// Push an input operand. If it is a float type, nothing to do. If it is
344360
/// an integer type, then we need to also push its signdness (1 for signed, 0
345361
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
346-
/// vector.
362+
/// vector. We also need to convert bfloat inputs to i16 to account for the lack
363+
/// of bfloat support in the WMMA intrinsics themselves.
347364
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
348365
Location loc,
349366
const TypeConverter *typeConverter,
@@ -353,6 +370,9 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
353370
auto vectorType = inputType.dyn_cast<VectorType>();
354371
Type elemType = vectorType.getElementType();
355372

373+
if (elemType.isBF16())
374+
llvmInput = rewriter.create<LLVM::BitcastOp>(
375+
loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
356376
if (!elemType.isInteger(8)) {
357377
operands.push_back(llvmInput);
358378
return;
@@ -392,8 +412,11 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
392412
Type inputType = output.getType();
393413
auto vectorType = inputType.dyn_cast<VectorType>();
394414
Type elemType = vectorType.getElementType();
415+
if (elemType.isBF16())
416+
output = rewriter.create<LLVM::BitcastOp>(
417+
loc, vectorType.clone(rewriter.getI16Type()), output);
395418
operands.push_back(output);
396-
if (elemType.isF16() || elemType.isBF16()) {
419+
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
397420
operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
398421
} else if (elemType.isInteger(32)) {
399422
operands.push_back(createI1Constant(rewriter, loc, clamp));
@@ -574,6 +597,10 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
574597
ConversionPatternRewriter &rewriter) const override {
575598
Location loc = op.getLoc();
576599
Type outType = typeConverter->convertType(op.getDestD().getType());
600+
Type intrinsicOutType = outType;
601+
if (auto outVecType = dyn_cast<VectorType>(outType))
602+
if (outVecType.getElementType().isBF16())
603+
intrinsicOutType = outVecType.clone(rewriter.getI16Type());
577604

578605
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08)
579606
return op->emitOpError("MFMA only supported on gfx908+");
@@ -588,15 +615,17 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
588615
if (!maybeIntrinsic.has_value())
589616
return op.emitOpError("no intrinsic matching MFMA size on given chipset");
590617
OperationState loweredOp(loc, *maybeIntrinsic);
591-
loweredOp.addTypes(outType);
618+
loweredOp.addTypes(intrinsicOutType);
592619
loweredOp.addOperands(
593620
{mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
594621
mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
595622
adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
596623
createI32Constant(rewriter, loc, op.getAbid()),
597624
createI32Constant(rewriter, loc, getBlgpField)});
598-
Operation *lowered = rewriter.create(loweredOp);
599-
rewriter.replaceOp(op, lowered->getResults());
625+
Value lowered = rewriter.create(loweredOp)->getResult(0);
626+
if (outType != intrinsicOutType)
627+
lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
628+
rewriter.replaceOp(op, lowered);
600629
return success();
601630
}
602631
};
@@ -669,6 +698,15 @@ struct ConvertAMDGPUToROCDLPass
669698
void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
670699
RewritePatternSet &patterns,
671700
Chipset chipset) {
701+
converter.addConversion([](BFloat16Type t) -> Type {
702+
return IntegerType::get(t.getContext(), 16);
703+
});
704+
converter.addConversion([&converter](VectorType t) -> std::optional<Type> {
705+
if (!t.getElementType().isBF16())
706+
return std::nullopt;
707+
return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16)));
708+
});
709+
672710
patterns.add<LDSBarrierOpLowering>(converter);
673711
patterns.add<
674712
RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,

mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
1111

1212
LINK_LIBS PUBLIC
1313
MLIRArithToLLVM
14+
MLIRArithTransforms
15+
MLIRMathToLLVM
1416
MLIRAMDGPUToROCDL
1517
MLIRFuncToLLVM
1618
MLIRGPUDialect

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313

1414
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
1515
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
16+
#include "mlir/Dialect/Arith/Transforms/Passes.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include "mlir/Pass/PassManager.h"
19+
#include "mlir/Transforms/Passes.h"
1620

1721
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
1822
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
@@ -163,6 +167,7 @@ struct LowerGpuOpsToROCDLOpsPass
163167
{
164168
RewritePatternSet patterns(ctx);
165169
populateGpuRewritePatterns(patterns);
170+
arith::populateExpandBFloat16Patterns(patterns);
166171
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
167172
}
168173

mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
136136
vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
137137
vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
138138
[&](Operation *op) { return converter.isLegal(op); });
139-
target.addLegalOp<arith::ExtFOp, arith::TruncFOp, arith::ConstantOp,
140-
vector::SplatOp>();
139+
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
140+
arith::ConstantOp, vector::SplatOp>();
141141
}
142142

143143
void EmulateUnsupportedFloatsPass::runOnOperation() {

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,11 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
363363
if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
364364
const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
365365
// Special case for 8-bit floats, which are represented by integers due to
366-
// the lack of native fp8 types in LLVM at the moment.
367-
if (APFloat::getSizeInBits(sem) == 8 && llvmType->isIntegerTy(8))
366+
// the lack of native fp8 types in LLVM at the moment. Additionally, handle
367+
// targets (like AMDGPU) that don't implement bfloat and convert all bfloats
368+
// to i16.
369+
unsigned floatWidth = APFloat::getSizeInBits(sem);
370+
if (llvmType->isIntegerTy(floatWidth))
368371
return llvm::ConstantInt::get(llvmType,
369372
floatAttr.getValue().bitcastToAPInt());
370373
if (llvmType !=

mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,25 +38,25 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
3838
amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
3939
// CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
4040
amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
41-
// CHECK: rocdl.mfma.f32.32x32x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
41+
// CHECK: rocdl.mfma.f32.32x32x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
4242
amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<32xf32>
43-
// CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
43+
// CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
4444
amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
45-
// CHECK: rocdl.mfma.f32.4x4x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
45+
// CHECK: rocdl.mfma.f32.4x4x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
4646
amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
47-
// CHECK: rocdl.mfma.f32.32x32x4bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
47+
// CHECK: rocdl.mfma.f32.32x32x4bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
4848
amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
49-
// CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
49+
// CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
5050
amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
51-
// CHECK: rocdl.mfma.f32.32x32x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
51+
// CHECK: rocdl.mfma.f32.32x32x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
5252
amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<32xf32>
53-
// CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
53+
// CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
5454
amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
55-
// CHECK: rocdl.mfma.f32.4x4x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
55+
// CHECK: rocdl.mfma.f32.4x4x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
5656
amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
57-
// CHECK: rocdl.mfma.f32.32x32x8bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
57+
// CHECK: rocdl.mfma.f32.32x32x8bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
5858
amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
59-
// CHECK: rocdl.mfma.f32.16x16x16bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
59+
// CHECK: rocdl.mfma.f32.16x16x16bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
6060
amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
6161
// CHECK: rocdl.mfma.f64.16x16x4f64{{.*}}: (f64, f64, vector<4xf64>, i32, i32, i32) -> vector<4xf64>
6262
amdgpu.mfma %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f64, f64, vector<4xf64>

mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,22 @@
22
func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
33
%arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
44
%arg6 : vector<16xi8>, %arg7 : vector<4xi32>, %arg8 : vector<8xi32>,
5-
%arg9 : vector<16xui8>){
5+
%arg9 : vector<16xui8>) {
66
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
77
amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
88
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
99
amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
10-
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<8xf32>) -> vector<8xf32>
10+
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
1111
amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
12-
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<4xf32>) -> vector<4xf32>
12+
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
1313
amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
1414
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
1515
amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
1616
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
1717
amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
18-
// CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<16xbf16>, i1) -> vector<16xbf16>
18+
// CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
1919
amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
20-
// CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<8xbf16>, i1) -> vector<8xbf16>
20+
// CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
2121
amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
2222
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
2323
amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<4xi32>

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,29 @@ gpu.module @test_module {
480480

481481
// -----
482482

483+
// Test that the bf16 type is lowered away on this target.
484+
485+
gpu.module @test_module {
486+
// CHECK-LABEL: func @bf16_id
487+
func.func @bf16_id(%arg0 : bf16) -> bf16 {
488+
// CHECK-SAME: (%[[ARG0:.+]]: i16)
489+
// CHECK-SAME: -> i16
490+
// CHECK: return %[[ARG0]] : i16
491+
func.return %arg0 : bf16
492+
}
493+
494+
// CHECK-LABEL: func @bf16x4_id
495+
func.func @bf16x4_id(%arg0 : vector<4xbf16>) -> vector<4xbf16> {
496+
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi16>)
497+
// CHECK-SAME: -> vector<4xi16>
498+
// CHECK: return %[[ARG0]] : vector<4xi16>
499+
func.return %arg0 : vector<4xbf16>
500+
}
501+
502+
}
503+
504+
// -----
505+
483506
gpu.module @test_module {
484507
// CHECK-LABEL: @kernel_func
485508
// CHECK: attributes

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ llvm.mlir.global internal @f8E5M2FNUZ_global_as_i8(1.5 : f8E5M2FNUZ) : i8
5454
// CHECK: @f8E4M3B11FNUZ_global_as_i8 = internal global i8 92
5555
llvm.mlir.global internal @f8E4M3B11FNUZ_global_as_i8(1.5 : f8E4M3B11FNUZ) : i8
5656

57+
// CHECK: @bf16_global_as_i16 = internal global i16 16320
58+
llvm.mlir.global internal @bf16_global_as_i16(1.5 : bf16) : i16
59+
5760
// CHECK: @explicit_undef = global i32 undef
5861
llvm.mlir.global external @explicit_undef() : i32 {
5962
%0 = llvm.mlir.undef : i32

0 commit comments

Comments
 (0)