@@ -273,7 +273,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
273
273
arith::ConstantOp constOp, OpAdaptor adaptor,
274
274
ConversionPatternRewriter &rewriter) const {
275
275
auto srcType = constOp.getType ().dyn_cast <ShapedType>();
276
- if (!srcType)
276
+ if (!srcType || srcType. getNumElements () == 1 )
277
277
return failure ();
278
278
279
279
// arith.constant should only have vector or tenor types.
@@ -358,16 +358,25 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
358
358
arith::ConstantOp constOp, OpAdaptor adaptor,
359
359
ConversionPatternRewriter &rewriter) const {
360
360
Type srcType = constOp.getType ();
361
+ if (auto shapedType = srcType.dyn_cast <ShapedType>()) {
362
+ if (shapedType.getNumElements () != 1 )
363
+ return failure ();
364
+ srcType = shapedType.getElementType ();
365
+ }
361
366
if (!srcType.isIntOrIndexOrFloat ())
362
367
return failure ();
363
368
369
+ Attribute cstAttr = constOp.getValue ();
370
+ if (cstAttr.getType ().isa <ShapedType>())
371
+ cstAttr = cstAttr.cast <DenseElementsAttr>().getSplatValue <Attribute>();
372
+
364
373
Type dstType = getTypeConverter ()->convertType (srcType);
365
374
if (!dstType)
366
375
return failure ();
367
376
368
377
// Floating-point types.
369
378
if (srcType.isa <FloatType>()) {
370
- auto srcAttr = constOp. getValue () .cast <FloatAttr>();
379
+ auto srcAttr = cstAttr .cast <FloatAttr>();
371
380
auto dstAttr = srcAttr;
372
381
373
382
// Floating-point types not supported in the target environment are all
@@ -386,7 +395,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
386
395
if (srcType.isInteger (1 )) {
387
396
// arith.constant can use 0/1 instead of true/false for i1 values. We need
388
397
// to handle that here.
389
- auto dstAttr = convertBoolAttr (constOp. getValue () , rewriter);
398
+ auto dstAttr = convertBoolAttr (cstAttr , rewriter);
390
399
if (!dstAttr)
391
400
return failure ();
392
401
rewriter.replaceOpWithNewOp <spirv::ConstantOp>(constOp, dstType, dstAttr);
@@ -395,7 +404,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
395
404
396
405
// IndexType or IntegerType. Index values are converted to 32-bit integer
397
406
// values when converting to SPIR-V.
398
- auto srcAttr = constOp. getValue () .cast <IntegerAttr>();
407
+ auto srcAttr = cstAttr .cast <IntegerAttr>();
399
408
auto dstAttr =
400
409
convertIntegerAttr (srcAttr, dstType.cast <IntegerType>(), rewriter);
401
410
if (!dstAttr)
0 commit comments