Skip to content

Commit f4e3b87

Browse files
authored
[mlir][LLVM] Switch undef for poison for uninitialized values (llvm#125629)
LLVM itself is generally moving away from using `undef` and towards using `poison`, to the point of having a lint that caches new uses of `undef` in tests. In order to not trip the lint on new patterns and to conform to the evolution of LLVM - Rename valious ::undef() methods on StructBuilder subclasses to ::poison() - Audit the uses of UndefOp in the MLIR libraries and replace almost all of them with PoisonOp The remaining uses of `undef` are initializing `uninitialized` memrefs, explicit conversions to undef from SPIR-V, and a few cases in AMDGPUToROCDL where usage like %v = insertelement <M x iN> undef, iN %v, i32 0 %arg = bitcast <M x iN> %v to i(M * N) is used to handle "i32" arguments that are are really packed vectors of smaller types that won't always be fully initialized.
1 parent e41ffd3 commit f4e3b87

40 files changed

+457
-455
lines changed

mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ class ComplexStructBuilder : public StructBuilder {
2424
/// Construct a helper for the given complex number value.
2525
using StructBuilder::StructBuilder;
2626
/// Build IR creating an `undef` value of the complex number type.
27-
static ComplexStructBuilder undef(OpBuilder &builder, Location loc,
28-
Type type);
27+
static ComplexStructBuilder poison(OpBuilder &builder, Location loc,
28+
Type type);
2929

3030
// Build IR extracting the real value from the complex number struct.
3131
Value real(OpBuilder &builder, Location loc);

mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ class MemRefDescriptor : public StructBuilder {
3434
public:
3535
/// Construct a helper for the given descriptor value.
3636
explicit MemRefDescriptor(Value descriptor);
37-
/// Builds IR creating an `undef` value of the descriptor type.
38-
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
39-
Type descriptorType);
37+
/// Builds IR creating a `poison` value of the descriptor type.
38+
static MemRefDescriptor poison(OpBuilder &builder, Location loc,
39+
Type descriptorType);
4040
/// Builds IR creating a MemRef descriptor that represents `type` and
4141
/// populates it with static shape and stride information extracted from the
4242
/// type.
@@ -160,8 +160,8 @@ class UnrankedMemRefDescriptor : public StructBuilder {
160160
/// Construct a helper for the given descriptor value.
161161
explicit UnrankedMemRefDescriptor(Value descriptor);
162162
/// Builds IR creating an `undef` value of the descriptor type.
163-
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
164-
Type descriptorType);
163+
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc,
164+
Type descriptorType);
165165

166166
/// Builds IR extracting the rank from the descriptor
167167
Value rank(OpBuilder &builder, Location loc) const;

mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ class StructBuilder {
2727
public:
2828
/// Construct a helper for the given value.
2929
explicit StructBuilder(Value v);
30-
/// Builds IR creating an `undef` value of the descriptor type.
31-
static StructBuilder undef(OpBuilder &builder, Location loc,
32-
Type descriptorType);
30+
/// Builds IR creating a `poison` value of the descriptor type.
31+
static StructBuilder poison(OpBuilder &builder, Location loc,
32+
Type descriptorType);
3333

3434
/*implicit*/ operator Value() { return value; }
3535

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
317317
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
318318
loc, DenseElementsAttr::get(predicateType, true));
319319
// Create padding vector (never used due to all-true predicate).
320-
auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
320+
auto padVector = rewriter.create<LLVM::PoisonOp>(loc, sliceType);
321321
// Get a pointer to the current slice.
322322
auto slicePtr =
323323
getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);

mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ using namespace mlir::arith;
3333
static constexpr unsigned kRealPosInComplexNumberStruct = 0;
3434
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
3535

36-
ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
37-
Location loc, Type type) {
38-
Value val = builder.create<LLVM::UndefOp>(loc, type);
36+
ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder,
37+
Location loc, Type type) {
38+
Value val = builder.create<LLVM::PoisonOp>(loc, type);
3939
return ComplexStructBuilder(val);
4040
}
4141

@@ -109,7 +109,8 @@ struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
109109
// Pack real and imaginary part in a complex number struct.
110110
auto loc = complexOp.getLoc();
111111
auto structType = typeConverter->convertType(complexOp.getType());
112-
auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
112+
auto complexStruct =
113+
ComplexStructBuilder::poison(rewriter, loc, structType);
113114
complexStruct.setReal(rewriter, loc, adaptor.getReal());
114115
complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
115116

@@ -183,7 +184,7 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
183184

184185
// Initialize complex number struct for result.
185186
auto structType = typeConverter->convertType(op.getType());
186-
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
187+
auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
187188

188189
// Emit IR to add complex numbers.
189190
arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
@@ -214,7 +215,7 @@ struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
214215

215216
// Initialize complex number struct for result.
216217
auto structType = typeConverter->convertType(op.getType());
217-
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
218+
auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
218219

219220
// Emit IR to add complex numbers.
220221
arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
@@ -262,7 +263,7 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
262263

263264
// Initialize complex number struct for result.
264265
auto structType = typeConverter->convertType(op.getType());
265-
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
266+
auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
266267

267268
// Emit IR to add complex numbers.
268269
arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
@@ -302,7 +303,7 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
302303

303304
// Initialize complex number struct for result.
304305
auto structType = typeConverter->convertType(op.getType());
305-
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
306+
auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
306307

307308
// Emit IR to substract complex numbers.
308309
arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ struct UnrealizedConversionCastOpLowering
660660
// `ReturnOp` interacts with the function signature and must have as many
661661
// operands as the function has return values. Because in LLVM IR, functions
662662
// can only return 0 or 1 value, we pack multiple values into a structure type.
663-
// Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
663+
// Emit `PoisonOp` followed by `InsertValueOp`s to create such structure if
664664
// necessary before returning it
665665
struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
666666
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
@@ -714,7 +714,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
714714
return rewriter.notifyMatchFailure(op, "could not convert result types");
715715
}
716716

717-
Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
717+
Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
718718
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
719719
packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
720720
}

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
603603
return rewriter.notifyMatchFailure(op, "expected vector result");
604604

605605
Location loc = op->getLoc();
606-
Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
606+
Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
607607
Type indexType = converter.convertType(rewriter.getIndexType());
608608
StringAttr name = op->getName().getIdentifier();
609609
Type elementType = vectorType.getElementType();
@@ -771,7 +771,7 @@ LogicalResult GPUReturnOpLowering::matchAndRewrite(
771771
return rewriter.notifyMatchFailure(op, "could not convert result types");
772772
}
773773

774-
Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
774+
Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
775775
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
776776
packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
777777
}

mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ struct WmmaConstantOpToNVVMLowering
279279
cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
280280
// If the element type is a vector create a vector from the operand.
281281
if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
282-
Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
282+
Value vecCst = rewriter.create<LLVM::PoisonOp>(loc, vecType);
283283
for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
284284
Value idx = rewriter.create<LLVM::ConstantOp>(
285285
loc, rewriter.getI32Type(), vecEl);
@@ -288,7 +288,7 @@ struct WmmaConstantOpToNVVMLowering
288288
}
289289
cst = vecCst;
290290
}
291-
Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type);
291+
Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, type);
292292
for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
293293
matrixStruct =
294294
rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
@@ -355,7 +355,7 @@ struct WmmaElementwiseOpToNVVMLowering
355355
size_t numOperands = adaptor.getOperands().size();
356356
LLVM::LLVMStructType destType = convertMMAToLLVMType(
357357
cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
358-
Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
358+
Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, destType);
359359
for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
360360
SmallVector<Value> extractedOperands;
361361
for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {

mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ MemRefDescriptor::MemRefDescriptor(Value descriptor)
2929
}
3030

3131
/// Builds IR creating an `undef` value of the descriptor type.
32-
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
33-
Type descriptorType) {
32+
MemRefDescriptor MemRefDescriptor::poison(OpBuilder &builder, Location loc,
33+
Type descriptorType) {
3434

35-
Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
35+
Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
3636
return MemRefDescriptor(descriptor);
3737
}
3838

@@ -60,7 +60,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
6060
auto convertedType = typeConverter.convertType(type);
6161
assert(convertedType && "unexpected failure in memref type conversion");
6262

63-
auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
63+
auto descr = MemRefDescriptor::poison(builder, loc, convertedType);
6464
descr.setAllocatedPtr(builder, loc, memory);
6565
descr.setAlignedPtr(builder, loc, alignedMemory);
6666
descr.setConstantOffset(builder, loc, offset);
@@ -224,7 +224,7 @@ Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
224224
const LLVMTypeConverter &converter,
225225
MemRefType type, ValueRange values) {
226226
Type llvmType = converter.convertType(type);
227-
auto d = MemRefDescriptor::undef(builder, loc, llvmType);
227+
auto d = MemRefDescriptor::poison(builder, loc, llvmType);
228228

229229
d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
230230
d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
@@ -300,10 +300,10 @@ UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
300300
: StructBuilder(descriptor) {}
301301

302302
/// Builds IR creating an `undef` value of the descriptor type.
303-
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
304-
Location loc,
305-
Type descriptorType) {
306-
Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
303+
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::poison(OpBuilder &builder,
304+
Location loc,
305+
Type descriptorType) {
306+
Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
307307
return UnrankedMemRefDescriptor(descriptor);
308308
}
309309
Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const {
@@ -331,7 +331,7 @@ Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
331331
UnrankedMemRefType type,
332332
ValueRange values) {
333333
Type llvmType = converter.convertType(type);
334-
auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
334+
auto d = UnrankedMemRefDescriptor::poison(builder, loc, llvmType);
335335

336336
d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
337337
d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
218218
ArrayRef<Value> sizes, ArrayRef<Value> strides,
219219
ConversionPatternRewriter &rewriter) const {
220220
auto structType = typeConverter->convertType(memRefType);
221-
auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
221+
auto memRefDescriptor = MemRefDescriptor::poison(rewriter, loc, structType);
222222

223223
// Field 1: Allocated pointer, used for malloc/free.
224224
memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
@@ -319,7 +319,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
319319
if (!descriptorType)
320320
return failure();
321321
auto updatedDesc =
322-
UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
322+
UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
323323
Value rank = desc.rank(builder, loc);
324324
updatedDesc.setRank(builder, loc, rank);
325325
updatedDesc.setMemRefDescPtr(builder, loc, memory);

mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
8787
auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
8888
auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
8989
auto loc = op->getLoc();
90-
Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
90+
Value desc = rewriter.create<LLVM::PoisonOp>(loc, resultNDVectoryTy);
9191
nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
9292
// For this unrolled `position` corresponding to the `linearIndex`^th
9393
// element, extract operand vectors

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -714,10 +714,10 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
714714
// rank = ConstantOp srcRank
715715
auto rankVal = rewriter.create<LLVM::ConstantOp>(
716716
loc, getIndexType(), rewriter.getIndexAttr(rank));
717-
// undef = UndefOp
717+
// poison = PoisonOp
718718
UnrankedMemRefDescriptor memRefDesc =
719-
UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
720-
// d1 = InsertValueOp undef, rank, 0
719+
UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
720+
// d1 = InsertValueOp poison, rank, 0
721721
memRefDesc.setRank(rewriter, loc, rankVal);
722722
// d2 = InsertValueOp d1, ptr, 1
723723
memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
@@ -928,7 +928,7 @@ struct MemorySpaceCastOpLowering
928928
Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
929929

930930
// Create and allocate storage for new memref descriptor.
931-
auto result = UnrankedMemRefDescriptor::undef(
931+
auto result = UnrankedMemRefDescriptor::poison(
932932
rewriter, loc, typeConverter->convertType(resultTypeU));
933933
result.setRank(rewriter, loc, rank);
934934
SmallVector<Value, 1> sizes;
@@ -1058,7 +1058,7 @@ struct MemRefReinterpretCastOpLowering
10581058

10591059
// Create descriptor.
10601060
Location loc = castOp.getLoc();
1061-
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1061+
auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
10621062

10631063
// Set allocated and aligned pointers.
10641064
Value allocatedPtr, alignedPtr;
@@ -1128,7 +1128,7 @@ struct MemRefReshapeOpLowering
11281128
// Create descriptor.
11291129
Location loc = reshapeOp.getLoc();
11301130
auto desc =
1131-
MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1131+
MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
11321132

11331133
// Set allocated and aligned pointers.
11341134
Value allocatedPtr, alignedPtr;
@@ -1210,7 +1210,7 @@ struct MemRefReshapeOpLowering
12101210

12111211
// Create the unranked memref descriptor that holds the ranked one. The
12121212
// inner descriptor is allocated on stack.
1213-
auto targetDesc = UnrankedMemRefDescriptor::undef(
1213+
auto targetDesc = UnrankedMemRefDescriptor::poison(
12141214
rewriter, loc, typeConverter->convertType(targetType));
12151215
targetDesc.setRank(rewriter, loc, resultRank);
12161216
SmallVector<Value, 4> sizes;
@@ -1366,7 +1366,7 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
13661366
if (transposeOp.getPermutation().isIdentity())
13671367
return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
13681368

1369-
auto targetMemRef = MemRefDescriptor::undef(
1369+
auto targetMemRef = MemRefDescriptor::poison(
13701370
rewriter, loc,
13711371
typeConverter->convertType(transposeOp.getIn().getType()));
13721372

@@ -1469,7 +1469,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
14691469

14701470
// Create the descriptor.
14711471
MemRefDescriptor sourceMemRef(adaptor.getSource());
1472-
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1472+
auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);
14731473

14741474
// Field 1: Copy the allocated pointer, used for malloc/free.
14751475
Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);

0 commit comments

Comments
 (0)