Skip to content

Commit 5c09dda

Browse files
authored
[mlir][emitc] Lower arith.index_cast, arith.index_castui, arith.shli, arith.shrui, arith.shrsi (#95795)
This PR makes use of the newly introduced EmitC types, and lowers: * ops dealing with index types (index_cast, index_castui), * ops where `size_t` is used as part of the lowering (shli, shrui, shrsi, to check for overflow and avoid UB in this case).
1 parent 6c84bba commit 5c09dda

File tree

4 files changed

+339
-20
lines changed

4 files changed

+339
-20
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

+125-12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/Dialect/Arith/IR/Arith.h"
1717
#include "mlir/Dialect/EmitC/IR/EmitC.h"
18+
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1819
#include "mlir/IR/BuiltinAttributes.h"
1920
#include "mlir/IR/BuiltinTypes.h"
2021
#include "mlir/Transforms/DialectConversion.h"
@@ -35,8 +36,11 @@ class ArithConstantOpConversionPattern
3536
matchAndRewrite(arith::ConstantOp arithConst,
3637
arith::ConstantOp::Adaptor adaptor,
3738
ConversionPatternRewriter &rewriter) const override {
38-
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(
39-
arithConst, arithConst.getType(), adaptor.getValue());
39+
Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
40+
if (!newTy)
41+
return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
42+
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
43+
adaptor.getValue());
4044
return success();
4145
}
4246
};
@@ -51,6 +55,12 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
5155
return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
5256
signedness);
5357
}
58+
} else if (emitc::isPointerWideType(ty)) {
59+
if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
60+
if (needsUnsigned)
61+
return emitc::SizeTType::get(ty.getContext());
62+
return emitc::PtrDiffTType::get(ty.getContext());
63+
}
5464
}
5565
return ty;
5666
}
@@ -263,8 +273,9 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
263273
ConversionPatternRewriter &rewriter) const override {
264274

265275
Type type = adaptor.getLhs().getType();
266-
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
267-
return rewriter.notifyMatchFailure(op, "expected integer or index type");
276+
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
277+
return rewriter.notifyMatchFailure(
278+
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
268279
}
269280

270281
bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
@@ -317,17 +328,21 @@ class CastConversion : public OpConversionPattern<ArithOp> {
317328
ConversionPatternRewriter &rewriter) const override {
318329

319330
Type opReturnType = this->getTypeConverter()->convertType(op.getType());
320-
if (!isa_and_nonnull<IntegerType>(opReturnType))
321-
return rewriter.notifyMatchFailure(op, "expected integer result type");
331+
if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
332+
emitc::isPointerWideType(opReturnType)))
333+
return rewriter.notifyMatchFailure(
334+
op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
322335

323336
if (adaptor.getOperands().size() != 1) {
324337
return rewriter.notifyMatchFailure(
325338
op, "CastConversion only supports unary ops");
326339
}
327340

328341
Type operandType = adaptor.getIn().getType();
329-
if (!isa_and_nonnull<IntegerType>(operandType))
330-
return rewriter.notifyMatchFailure(op, "expected integer operand type");
342+
if (!operandType || !(isa<IntegerType>(operandType) ||
343+
emitc::isPointerWideType(operandType)))
344+
return rewriter.notifyMatchFailure(
345+
op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
331346

332347
// Signed (sign-extending) casts from i1 are not supported.
333348
if (operandType.isInteger(1) && !castToUnsigned)
@@ -338,8 +353,11 @@ class CastConversion : public OpConversionPattern<ArithOp> {
338353
// equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
339354
// truncation.
340355
if (opReturnType.isInteger(1)) {
356+
Type attrType = (emitc::isPointerWideType(operandType))
357+
? rewriter.getIndexType()
358+
: operandType;
341359
auto constOne = rewriter.create<emitc::ConstantOp>(
342-
op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1));
360+
op.getLoc(), operandType, rewriter.getOneAttr(attrType));
343361
auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
344362
op.getLoc(), operandType, adaptor.getIn(), constOne);
345363
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
@@ -392,7 +410,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
392410
matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
393411
ConversionPatternRewriter &rewriter) const override {
394412

395-
rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
413+
Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
414+
if (!newTy)
415+
return rewriter.notifyMatchFailure(arithOp,
416+
"converting result type failed");
417+
rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
396418
adaptor.getOperands());
397419

398420
return success();
@@ -409,8 +431,9 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
409431
ConversionPatternRewriter &rewriter) const override {
410432

411433
Type type = this->getTypeConverter()->convertType(op.getType());
412-
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
413-
return rewriter.notifyMatchFailure(op, "expected integer type");
434+
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
435+
return rewriter.notifyMatchFailure(
436+
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
414437
}
415438

416439
if (type.isInteger(1)) {
@@ -481,6 +504,89 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
481504
}
482505
};
483506

507+
template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
508+
class ShiftOpConversion : public OpConversionPattern<ArithOp> {
509+
public:
510+
using OpConversionPattern<ArithOp>::OpConversionPattern;
511+
512+
LogicalResult
513+
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
514+
ConversionPatternRewriter &rewriter) const override {
515+
516+
Type type = this->getTypeConverter()->convertType(op.getType());
517+
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
518+
return rewriter.notifyMatchFailure(
519+
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
520+
}
521+
522+
if (type.isInteger(1)) {
523+
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
524+
}
525+
526+
Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
527+
528+
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
529+
// Shift amount interpreted as unsigned per Arith dialect spec.
530+
Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
531+
/*needsUnsigned=*/true);
532+
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
533+
534+
// Add a runtime check for overflow
535+
Value width;
536+
if (emitc::isPointerWideType(type)) {
537+
Value eight = rewriter.create<emitc::ConstantOp>(
538+
op.getLoc(), rhsType, rewriter.getIndexAttr(8));
539+
emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
540+
op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
541+
width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
542+
sizeOfCall.getResult(0));
543+
} else {
544+
width = rewriter.create<emitc::ConstantOp>(
545+
op.getLoc(), rhsType,
546+
rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
547+
}
548+
549+
Value excessCheck = rewriter.create<emitc::CmpOp>(
550+
op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
551+
552+
// Any concrete value is a valid refinement of poison.
553+
Value poison = rewriter.create<emitc::ConstantOp>(
554+
op.getLoc(), arithmeticType,
555+
(isa<IntegerType>(arithmeticType)
556+
? rewriter.getIntegerAttr(arithmeticType, 0)
557+
: rewriter.getIndexAttr(0)));
558+
559+
emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
560+
op.getLoc(), arithmeticType, /*do_not_inline=*/false);
561+
Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
562+
auto currentPoint = rewriter.getInsertionPoint();
563+
rewriter.setInsertionPointToStart(&bodyBlock);
564+
Value arithmeticResult =
565+
rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
566+
Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
567+
op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
568+
rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
569+
rewriter.setInsertionPoint(op->getBlock(), currentPoint);
570+
571+
Value result = adaptValueType(ternary, rewriter, type);
572+
573+
rewriter.replaceOp(op, result);
574+
return success();
575+
}
576+
};
577+
578+
template <typename ArithOp, typename EmitCOp>
579+
class SignedShiftOpConversion final
580+
: public ShiftOpConversion<ArithOp, EmitCOp, false> {
581+
using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
582+
};
583+
584+
template <typename ArithOp, typename EmitCOp>
585+
class UnsignedShiftOpConversion final
586+
: public ShiftOpConversion<ArithOp, EmitCOp, true> {
587+
using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
588+
};
589+
484590
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
485591
public:
486592
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -605,6 +711,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
605711
RewritePatternSet &patterns) {
606712
MLIRContext *ctx = patterns.getContext();
607713

714+
mlir::populateEmitCSizeTTypeConversions(typeConverter);
715+
608716
// clang-format off
609717
patterns.add<
610718
ArithConstantOpConversionPattern,
@@ -620,6 +728,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
620728
BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
621729
BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
622730
BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
731+
UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
732+
SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
733+
UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
623734
CmpFOpConversion,
624735
CmpIOpConversion,
625736
NegFOpConversion,
@@ -628,6 +739,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
628739
UnsignedCastConversion<arith::TruncIOp>,
629740
SignedCastConversion<arith::ExtSIOp>,
630741
UnsignedCastConversion<arith::ExtUIOp>,
742+
SignedCastConversion<arith::IndexCastOp>,
743+
UnsignedCastConversion<arith::IndexCastUIOp>,
631744
ItoFCastOpConversion<arith::SIToFPOp>,
632745
ItoFCastOpConversion<arith::UIToFPOp>,
633746
FtoICastOpConversion<arith::FPToSIOp>,

mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRArithToEmitC
1111
LINK_LIBS PUBLIC
1212
MLIRArithDialect
1313
MLIREmitCDialect
14+
MLIREmitCTransforms
1415
MLIRPass
1516
MLIRTransformUtils
1617
)

mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir

+24
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,27 @@ func.func @arith_extsi_i1_to_i32(%arg0: i1) {
110110
%idx = arith.extsi %arg0 : i1 to i32
111111
return
112112
}
113+
114+
// -----
115+
116+
func.func @arith_shli_i1(%arg0: i1, %arg1: i1) {
117+
// expected-error @+1 {{failed to legalize operation 'arith.shli'}}
118+
%shli = arith.shli %arg0, %arg1 : i1
119+
return
120+
}
121+
122+
// -----
123+
124+
func.func @arith_shrsi_i1(%arg0: i1, %arg1: i1) {
125+
// expected-error @+1 {{failed to legalize operation 'arith.shrsi'}}
126+
%shrsi = arith.shrsi %arg0, %arg1 : i1
127+
return
128+
}
129+
130+
// -----
131+
132+
func.func @arith_shrui_i1(%arg0: i1, %arg1: i1) {
133+
// expected-error @+1 {{failed to legalize operation 'arith.shrui'}}
134+
%shrui = arith.shrui %arg0, %arg1 : i1
135+
return
136+
}

0 commit comments

Comments
 (0)