15
15
16
16
#include " mlir/Dialect/Arith/IR/Arith.h"
17
17
#include " mlir/Dialect/EmitC/IR/EmitC.h"
18
+ #include " mlir/Dialect/EmitC/Transforms/TypeConversions.h"
18
19
#include " mlir/IR/BuiltinAttributes.h"
19
20
#include " mlir/IR/BuiltinTypes.h"
20
21
#include " mlir/Transforms/DialectConversion.h"
@@ -35,8 +36,11 @@ class ArithConstantOpConversionPattern
35
36
matchAndRewrite (arith::ConstantOp arithConst,
36
37
arith::ConstantOp::Adaptor adaptor,
37
38
ConversionPatternRewriter &rewriter) const override {
38
- rewriter.replaceOpWithNewOp <emitc::ConstantOp>(
39
- arithConst, arithConst.getType (), adaptor.getValue ());
39
+ Type newTy = this ->getTypeConverter ()->convertType (arithConst.getType ());
40
+ if (!newTy)
41
+ return rewriter.notifyMatchFailure (arithConst, " type conversion failed" );
42
+ rewriter.replaceOpWithNewOp <emitc::ConstantOp>(arithConst, newTy,
43
+ adaptor.getValue ());
40
44
return success ();
41
45
}
42
46
};
@@ -51,6 +55,12 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
51
55
return IntegerType::get (ty.getContext (), ty.getIntOrFloatBitWidth (),
52
56
signedness);
53
57
}
58
+ } else if (emitc::isPointerWideType (ty)) {
59
+ if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
60
+ if (needsUnsigned)
61
+ return emitc::SizeTType::get (ty.getContext ());
62
+ return emitc::PtrDiffTType::get (ty.getContext ());
63
+ }
54
64
}
55
65
return ty;
56
66
}
@@ -263,8 +273,9 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
263
273
ConversionPatternRewriter &rewriter) const override {
264
274
265
275
Type type = adaptor.getLhs ().getType ();
266
- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
267
- return rewriter.notifyMatchFailure (op, " expected integer or index type" );
276
+ if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType (type))) {
277
+ return rewriter.notifyMatchFailure (
278
+ op, " expected integer or size_t/ssize_t/ptrdiff_t type" );
268
279
}
269
280
270
281
bool needsUnsigned = needsUnsignedCmp (op.getPredicate ());
@@ -317,17 +328,21 @@ class CastConversion : public OpConversionPattern<ArithOp> {
317
328
ConversionPatternRewriter &rewriter) const override {
318
329
319
330
Type opReturnType = this ->getTypeConverter ()->convertType (op.getType ());
320
- if (!isa_and_nonnull<IntegerType>(opReturnType))
321
- return rewriter.notifyMatchFailure (op, " expected integer result type" );
331
+ if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
332
+ emitc::isPointerWideType (opReturnType)))
333
+ return rewriter.notifyMatchFailure (
334
+ op, " expected integer or size_t/ssize_t/ptrdiff_t result type" );
322
335
323
336
if (adaptor.getOperands ().size () != 1 ) {
324
337
return rewriter.notifyMatchFailure (
325
338
op, " CastConversion only supports unary ops" );
326
339
}
327
340
328
341
Type operandType = adaptor.getIn ().getType ();
329
- if (!isa_and_nonnull<IntegerType>(operandType))
330
- return rewriter.notifyMatchFailure (op, " expected integer operand type" );
342
+ if (!operandType || !(isa<IntegerType>(operandType) ||
343
+ emitc::isPointerWideType (operandType)))
344
+ return rewriter.notifyMatchFailure (
345
+ op, " expected integer or size_t/ssize_t/ptrdiff_t operand type" );
331
346
332
347
// Signed (sign-extending) casts from i1 are not supported.
333
348
if (operandType.isInteger (1 ) && !castToUnsigned)
@@ -338,8 +353,11 @@ class CastConversion : public OpConversionPattern<ArithOp> {
338
353
// equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
339
354
// truncation.
340
355
if (opReturnType.isInteger (1 )) {
356
+ Type attrType = (emitc::isPointerWideType (operandType))
357
+ ? rewriter.getIndexType ()
358
+ : operandType;
341
359
auto constOne = rewriter.create <emitc::ConstantOp>(
342
- op.getLoc (), operandType, rewriter.getIntegerAttr (operandType, 1 ));
360
+ op.getLoc (), operandType, rewriter.getOneAttr (attrType ));
343
361
auto oneAndOperand = rewriter.create <emitc::BitwiseAndOp>(
344
362
op.getLoc (), operandType, adaptor.getIn (), constOne);
345
363
rewriter.replaceOpWithNewOp <emitc::CastOp>(op, opReturnType,
@@ -392,7 +410,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
392
410
matchAndRewrite (ArithOp arithOp, typename ArithOp::Adaptor adaptor,
393
411
ConversionPatternRewriter &rewriter) const override {
394
412
395
- rewriter.template replaceOpWithNewOp <EmitCOp>(arithOp, arithOp.getType (),
413
+ Type newTy = this ->getTypeConverter ()->convertType (arithOp.getType ());
414
+ if (!newTy)
415
+ return rewriter.notifyMatchFailure (arithOp,
416
+ " converting result type failed" );
417
+ rewriter.template replaceOpWithNewOp <EmitCOp>(arithOp, newTy,
396
418
adaptor.getOperands ());
397
419
398
420
return success ();
@@ -409,8 +431,9 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
409
431
ConversionPatternRewriter &rewriter) const override {
410
432
411
433
Type type = this ->getTypeConverter ()->convertType (op.getType ());
412
- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
413
- return rewriter.notifyMatchFailure (op, " expected integer type" );
434
+ if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType (type))) {
435
+ return rewriter.notifyMatchFailure (
436
+ op, " expected integer or size_t/ssize_t/ptrdiff_t type" );
414
437
}
415
438
416
439
if (type.isInteger (1 )) {
@@ -481,6 +504,89 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
481
504
}
482
505
};
483
506
507
+ template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
508
+ class ShiftOpConversion : public OpConversionPattern <ArithOp> {
509
+ public:
510
+ using OpConversionPattern<ArithOp>::OpConversionPattern;
511
+
512
+ LogicalResult
513
+ matchAndRewrite (ArithOp op, typename ArithOp::Adaptor adaptor,
514
+ ConversionPatternRewriter &rewriter) const override {
515
+
516
+ Type type = this ->getTypeConverter ()->convertType (op.getType ());
517
+ if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType (type))) {
518
+ return rewriter.notifyMatchFailure (
519
+ op, " expected integer or size_t/ssize_t/ptrdiff_t type" );
520
+ }
521
+
522
+ if (type.isInteger (1 )) {
523
+ return rewriter.notifyMatchFailure (op, " i1 type is not implemented" );
524
+ }
525
+
526
+ Type arithmeticType = adaptIntegralTypeSignedness (type, isUnsignedOp);
527
+
528
+ Value lhs = adaptValueType (adaptor.getLhs (), rewriter, arithmeticType);
529
+ // Shift amount interpreted as unsigned per Arith dialect spec.
530
+ Type rhsType = adaptIntegralTypeSignedness (adaptor.getRhs ().getType (),
531
+ /* needsUnsigned=*/ true );
532
+ Value rhs = adaptValueType (adaptor.getRhs (), rewriter, rhsType);
533
+
534
+ // Add a runtime check for overflow
535
+ Value width;
536
+ if (emitc::isPointerWideType (type)) {
537
+ Value eight = rewriter.create <emitc::ConstantOp>(
538
+ op.getLoc (), rhsType, rewriter.getIndexAttr (8 ));
539
+ emitc::CallOpaqueOp sizeOfCall = rewriter.create <emitc::CallOpaqueOp>(
540
+ op.getLoc (), rhsType, " sizeof" , ArrayRef<Value>{eight});
541
+ width = rewriter.create <emitc::MulOp>(op.getLoc (), rhsType, eight,
542
+ sizeOfCall.getResult (0 ));
543
+ } else {
544
+ width = rewriter.create <emitc::ConstantOp>(
545
+ op.getLoc (), rhsType,
546
+ rewriter.getIntegerAttr (rhsType, type.getIntOrFloatBitWidth ()));
547
+ }
548
+
549
+ Value excessCheck = rewriter.create <emitc::CmpOp>(
550
+ op.getLoc (), rewriter.getI1Type (), emitc::CmpPredicate::lt, rhs, width);
551
+
552
+ // Any concrete value is a valid refinement of poison.
553
+ Value poison = rewriter.create <emitc::ConstantOp>(
554
+ op.getLoc (), arithmeticType,
555
+ (isa<IntegerType>(arithmeticType)
556
+ ? rewriter.getIntegerAttr (arithmeticType, 0 )
557
+ : rewriter.getIndexAttr (0 )));
558
+
559
+ emitc::ExpressionOp ternary = rewriter.create <emitc::ExpressionOp>(
560
+ op.getLoc (), arithmeticType, /* do_not_inline=*/ false );
561
+ Block &bodyBlock = ternary.getBodyRegion ().emplaceBlock ();
562
+ auto currentPoint = rewriter.getInsertionPoint ();
563
+ rewriter.setInsertionPointToStart (&bodyBlock);
564
+ Value arithmeticResult =
565
+ rewriter.create <EmitCOp>(op.getLoc (), arithmeticType, lhs, rhs);
566
+ Value resultOrPoison = rewriter.create <emitc::ConditionalOp>(
567
+ op.getLoc (), arithmeticType, excessCheck, arithmeticResult, poison);
568
+ rewriter.create <emitc::YieldOp>(op.getLoc (), resultOrPoison);
569
+ rewriter.setInsertionPoint (op->getBlock (), currentPoint);
570
+
571
+ Value result = adaptValueType (ternary, rewriter, type);
572
+
573
+ rewriter.replaceOp (op, result);
574
+ return success ();
575
+ }
576
+ };
577
+
578
+ template <typename ArithOp, typename EmitCOp>
579
+ class SignedShiftOpConversion final
580
+ : public ShiftOpConversion<ArithOp, EmitCOp, false > {
581
+ using ShiftOpConversion<ArithOp, EmitCOp, false >::ShiftOpConversion;
582
+ };
583
+
584
+ template <typename ArithOp, typename EmitCOp>
585
+ class UnsignedShiftOpConversion final
586
+ : public ShiftOpConversion<ArithOp, EmitCOp, true > {
587
+ using ShiftOpConversion<ArithOp, EmitCOp, true >::ShiftOpConversion;
588
+ };
589
+
484
590
class SelectOpConversion : public OpConversionPattern <arith::SelectOp> {
485
591
public:
486
592
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -605,6 +711,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
605
711
RewritePatternSet &patterns) {
606
712
MLIRContext *ctx = patterns.getContext ();
607
713
714
+ mlir::populateEmitCSizeTTypeConversions (typeConverter);
715
+
608
716
// clang-format off
609
717
patterns.add <
610
718
ArithConstantOpConversionPattern,
@@ -620,6 +728,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
620
728
BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
621
729
BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
622
730
BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
731
+ UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
732
+ SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
733
+ UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
623
734
CmpFOpConversion,
624
735
CmpIOpConversion,
625
736
NegFOpConversion,
@@ -628,6 +739,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
628
739
UnsignedCastConversion<arith::TruncIOp>,
629
740
SignedCastConversion<arith::ExtSIOp>,
630
741
UnsignedCastConversion<arith::ExtUIOp>,
742
+ SignedCastConversion<arith::IndexCastOp>,
743
+ UnsignedCastConversion<arith::IndexCastUIOp>,
631
744
ItoFCastOpConversion<arith::SIToFPOp>,
632
745
ItoFCastOpConversion<arith::UIToFPOp>,
633
746
FtoICastOpConversion<arith::FPToSIOp>,
0 commit comments