Skip to content

Commit 96130b5

Browse files
committed
[mlir][spirv] Support size-1 vector/tensor constant during conversion
Reviewed By: ThomasRaoux, mravishankar Differential Revision: https://reviews.llvm.org/D115518
1 parent b5c49b6 commit 96130b5

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
273273
arith::ConstantOp constOp, OpAdaptor adaptor,
274274
ConversionPatternRewriter &rewriter) const {
275275
auto srcType = constOp.getType().dyn_cast<ShapedType>();
276-
if (!srcType)
276+
if (!srcType || srcType.getNumElements() == 1)
277277
return failure();
278278

279279
// arith.constant should only have vector or tenor types.
@@ -358,16 +358,25 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
358358
arith::ConstantOp constOp, OpAdaptor adaptor,
359359
ConversionPatternRewriter &rewriter) const {
360360
Type srcType = constOp.getType();
361+
if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
362+
if (shapedType.getNumElements() != 1)
363+
return failure();
364+
srcType = shapedType.getElementType();
365+
}
361366
if (!srcType.isIntOrIndexOrFloat())
362367
return failure();
363368

369+
Attribute cstAttr = constOp.getValue();
370+
if (cstAttr.getType().isa<ShapedType>())
371+
cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
372+
364373
Type dstType = getTypeConverter()->convertType(srcType);
365374
if (!dstType)
366375
return failure();
367376

368377
// Floating-point types.
369378
if (srcType.isa<FloatType>()) {
370-
auto srcAttr = constOp.getValue().cast<FloatAttr>();
379+
auto srcAttr = cstAttr.cast<FloatAttr>();
371380
auto dstAttr = srcAttr;
372381

373382
// Floating-point types not supported in the target environment are all
@@ -386,7 +395,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
386395
if (srcType.isInteger(1)) {
387396
// arith.constant can use 0/1 instead of true/false for i1 values. We need
388397
// to handle that here.
389-
auto dstAttr = convertBoolAttr(constOp.getValue(), rewriter);
398+
auto dstAttr = convertBoolAttr(cstAttr, rewriter);
390399
if (!dstAttr)
391400
return failure();
392401
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
@@ -395,7 +404,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
395404

396405
// IndexType or IntegerType. Index values are converted to 32-bit integer
397406
// values when converting to SPIR-V.
398-
auto srcAttr = constOp.getValue().cast<IntegerAttr>();
407+
auto srcAttr = cstAttr.cast<IntegerAttr>();
399408
auto dstAttr =
400409
convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
401410
if (!dstAttr)

mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,17 @@ func @constant_64bit() {
446446
return
447447
}
448448

449+
// CHECK-LABEL: @constant_size1
450+
func @constant_size1() {
451+
// CHECK: spv.Constant true
452+
%0 = arith.constant dense<true> : tensor<1xi1>
453+
// CHECK: spv.Constant 4 : i64
454+
%1 = arith.constant dense<4> : vector<1xi64>
455+
// CHECK: spv.Constant 5.000000e+00 : f64
456+
%2 = arith.constant dense<5.0> : tensor<1xf64>
457+
return
458+
}
459+
449460
} // end module
450461

451462
// -----
@@ -485,6 +496,15 @@ func @constant_64bit() {
485496
return
486497
}
487498

499+
// CHECK-LABEL: @constant_size1
500+
func @constant_size1() {
501+
// CHECK: spv.Constant 4 : i32
502+
%0 = arith.constant dense<4> : vector<1xi64>
503+
// CHECK: spv.Constant 5.000000e+00 : f32
504+
%1 = arith.constant dense<5.0> : tensor<1xf64>
505+
return
506+
}
507+
488508
// CHECK-LABEL: @corner_cases
489509
func @corner_cases() {
490510
// CHECK: %{{.*}} = spv.Constant -1 : i32

0 commit comments

Comments
 (0)