Skip to content

Commit 4a7b56e

Browse files
authored
[MLIR][Arith] Add denormal attribute to binary/unary operations (#112700)
Add support for denormal in the Arith dialect (binary and unary operations). Denormal are attached to every operation, and they can be of three different kinds: 1) ieee, denormal are preserved and processed as defined by IEEE 754 rules. 2) preserve sign, a mode where denormal numbers are flushed to zero, but the sign of the zero (+0 or -0) is preserved. 3) positive zero, a mode where all denormal numbers are flushed to positive zero (+0), ignoring the sign of the original number. Denormal refers to both the operands and the result. Currently only lowering for ieee is supported.
1 parent 45fdb77 commit 4a7b56e

File tree

13 files changed

+263
-50
lines changed

13 files changed

+263
-50
lines changed

mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ getLLVMDefaultFPExceptionBehavior(MLIRContext &context);
5151
template <typename SourceOp, typename TargetOp>
5252
class AttrConvertFastMathToLLVM {
5353
public:
54-
AttrConvertFastMathToLLVM(SourceOp srcOp) {
54+
explicit AttrConvertFastMathToLLVM(SourceOp srcOp) {
5555
// Copy the source attributes.
5656
convertedAttr = NamedAttrList{srcOp->getAttrs()};
5757
// Get the name of the arith fastmath attribute.
@@ -81,7 +81,7 @@ class AttrConvertFastMathToLLVM {
8181
template <typename SourceOp, typename TargetOp>
8282
class AttrConvertOverflowToLLVM {
8383
public:
84-
AttrConvertOverflowToLLVM(SourceOp srcOp) {
84+
explicit AttrConvertOverflowToLLVM(SourceOp srcOp) {
8585
// Copy the source attributes.
8686
convertedAttr = NamedAttrList{srcOp->getAttrs()};
8787
// Get the name of the arith overflow attribute.
@@ -109,7 +109,7 @@ class AttrConverterConstrainedFPToLLVM {
109109
"LLVM::FPExceptionBehaviorOpInterface");
110110

111111
public:
112-
AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
112+
explicit AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
113113
// Copy the source attributes.
114114
convertedAttr = NamedAttrList{srcOp->getAttrs()};
115115

mlir/include/mlir/Dialect/Arith/IR/ArithBase.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,4 +181,37 @@ def Arith_RoundingModeAttr : I32EnumAttr<
181181
let cppNamespace = "::mlir::arith";
182182
}
183183

184+
//===----------------------------------------------------------------------===//
185+
// Arith_DenormalMode
186+
//===----------------------------------------------------------------------===//
187+
188+
// Denormal mode is applied on operands and results. For example, if denormal =
189+
// preserve_sign, operands and results will be flushed to sign preserving zero.
190+
// We do not distinguish between operands and results.
191+
192+
// The default mode. Denormals are preserved and processed as defined
193+
// by IEEE 754 rules.
194+
def Arith_DenormalModeIEEE : I32EnumAttrCase<"ieee", 0>;
195+
196+
// A mode where denormal numbers are flushed to zero, but the sign of the zero
197+
// (+0 or -0) is preserved.
198+
def Arith_DenormalModePreserveSign : I32EnumAttrCase<"preserve_sign", 1>;
199+
200+
// A mode where all denormal numbers are flushed to positive zero (+0),
201+
// ignoring the sign of the original number.
202+
def Arith_DenormalModePositiveZero : I32EnumAttrCase<"positive_zero", 2>;
203+
204+
def Arith_DenormalMode : I32EnumAttr<
205+
"DenormalMode", "denormal mode arith",
206+
[Arith_DenormalModeIEEE, Arith_DenormalModePreserveSign,
207+
Arith_DenormalModePositiveZero]> {
208+
let cppNamespace = "::mlir::arith";
209+
let genSpecializedAttr = 0;
210+
}
211+
212+
def Arith_DenormalModeAttr :
213+
EnumAttr<Arith_Dialect, Arith_DenormalMode, "denormal"> {
214+
let assemblyFormat = "`<` $value `>`";
215+
}
216+
184217
#endif // ARITH_BASE

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,26 +61,35 @@ class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
6161
// Base class for floating point unary operations.
6262
class Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
6363
Arith_UnaryOp<mnemonic,
64-
!listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
64+
!listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>,
65+
DeclareOpInterfaceMethods<ArithDenormalModeInterface>],
6566
traits)>,
6667
Arguments<(ins FloatLike:$operand,
6768
DefaultValuedAttr<
68-
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
69+
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
70+
DefaultValuedAttr<
71+
Arith_DenormalModeAttr, "::mlir::arith::DenormalMode::ieee">:$denormal)>,
6972
Results<(outs FloatLike:$result)> {
7073
let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
74+
(`denormal` `` $denormal^)?
7175
attr-dict `:` type($result) }];
7276
}
7377

7478
// Base class for floating point binary operations.
7579
class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
7680
Arith_BinaryOp<mnemonic,
77-
!listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>],
81+
!listconcat([Pure,
82+
DeclareOpInterfaceMethods<ArithFastMathInterface>,
83+
DeclareOpInterfaceMethods<ArithDenormalModeInterface>],
7884
traits)>,
7985
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
8086
DefaultValuedAttr<
81-
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
87+
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
88+
DefaultValuedAttr<
89+
Arith_DenormalModeAttr, "::mlir::arith::DenormalMode::ieee">:$denormal)>,
8290
Results<(outs FloatLike:$result)> {
83-
let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
91+
let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
92+
(`denormal` `` $denormal^)?
8493
attr-dict `:` type($result) }];
8594
}
8695

@@ -1085,7 +1094,6 @@ def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
10851094
let hasFolder = 1;
10861095
}
10871096

1088-
10891097
//===----------------------------------------------------------------------===//
10901098
// MulFOp
10911099
//===----------------------------------------------------------------------===//
@@ -1111,8 +1119,6 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
11111119
%x = arith.mulf %y, %z : tensor<4x?xbf16>
11121120
```
11131121

1114-
TODO: In the distant future, this will accept optional attributes for fast
1115-
math, contraction, rounding mode, and other controls.
11161122
}];
11171123
let hasFolder = 1;
11181124
let hasCanonicalizer = 1;

mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,12 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
4545
return "fastmath";
4646
}]
4747
>
48-
4948
];
5049
}
5150

5251
def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
5352
let description = [{
54-
Access to op integer overflow flags.
53+
Access to operation integer overflow flags.
5554
}];
5655

5756
let cppNamespace = "::mlir::arith";
@@ -108,7 +107,7 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI
108107

109108
def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
110109
let description = [{
111-
Access to op rounding mode.
110+
Access to operation rounding mode.
112111
}];
113112

114113
let cppNamespace = "::mlir::arith";
@@ -139,4 +138,39 @@ def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
139138
];
140139
}
141140

141+
142+
def ArithDenormalModeInterface : OpInterface<"ArithDenormalModeInterface"> {
143+
let description = [{
144+
Access the operation denormal modes.
145+
}];
146+
147+
let cppNamespace = "::mlir::arith";
148+
149+
let methods = [
150+
InterfaceMethod<
151+
/*desc=*/ "Returns a DenormalModeAttr attribute for the operation",
152+
/*returnType=*/ "DenormalModeAttr",
153+
/*methodName=*/ "getDenormalModeAttr",
154+
/*args=*/ (ins),
155+
/*methodBody=*/ [{}],
156+
/*defaultImpl=*/ [{
157+
auto op = cast<ConcreteOp>(this->getOperation());
158+
return op.getDenormalAttr();
159+
}]
160+
>,
161+
StaticInterfaceMethod<
162+
/*desc=*/ [{Returns the name of the DenormalModeAttr attribute for
163+
the operation}],
164+
/*returnType=*/ "StringRef",
165+
/*methodName=*/ "getDenormalModeAttrName",
166+
/*args=*/ (ins),
167+
/*methodBody=*/ [{}],
168+
/*defaultImpl=*/ [{
169+
return "denormal";
170+
}]
171+
>
172+
];
173+
}
174+
175+
142176
#endif // ARITH_OPS_INTERFACES

mlir/include/mlir/IR/Matchers.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,12 @@ inline detail::constant_float_predicate_matcher m_NegInfFloat() {
433433
}};
434434
}
435435

436+
/// Matches a constant scalar / vector splat / tensor splat with denormal
437+
/// values.
438+
inline detail::constant_float_predicate_matcher m_isDenormalFloat() {
439+
return {[](const APFloat &value) { return value.isDenormal(); }};
440+
}
441+
436442
/// Matches a constant scalar / vector splat / tensor splat integer zero.
437443
inline detail::constant_int_predicate_matcher m_Zero() {
438444
return {[](const APInt &value) { return 0 == value; }};

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -53,22 +53,49 @@ struct ConstrainedVectorConvertToLLVMPattern
5353
}
5454
};
5555

56+
template <typename SourceOp, typename TargetOp,
57+
template <typename, typename> typename AttrConvert =
58+
AttrConvertPassThrough>
59+
struct DenormalOpConversionToLLVMPattern
60+
: public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
61+
using VectorConvertToLLVMPattern<SourceOp, TargetOp,
62+
AttrConvert>::VectorConvertToLLVMPattern;
63+
64+
LogicalResult
65+
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
66+
ConversionPatternRewriter &rewriter) const override {
67+
// TODO: Here, we need a legalization step. LLVM provides a function-level
68+
// attribute for denormal; here, we need to move this information from the
69+
// operation to the function, making sure all the operations in the same
70+
// function are consistent.
71+
if (op.getDenormalModeAttr().getValue() != arith::DenormalMode::ieee)
72+
return rewriter.notifyMatchFailure(
73+
op, "only ieee denormal mode is supported at the moment");
74+
75+
StringRef arithDenormalAttrName = SourceOp::getDenormalModeAttrName();
76+
op->removeAttr(arithDenormalAttrName);
77+
return VectorConvertToLLVMPattern<SourceOp, TargetOp,
78+
AttrConvert>::matchAndRewrite(op, adaptor,
79+
rewriter);
80+
}
81+
};
82+
5683
//===----------------------------------------------------------------------===//
5784
// Straightforward Op Lowerings
5885
//===----------------------------------------------------------------------===//
5986

6087
using AddFOpLowering =
61-
VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
62-
arith::AttrConvertFastMathToLLVM>;
88+
DenormalOpConversionToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
89+
arith::AttrConvertFastMathToLLVM>;
6390
using AddIOpLowering =
6491
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
6592
arith::AttrConvertOverflowToLLVM>;
6693
using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
6794
using BitcastOpLowering =
6895
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
6996
using DivFOpLowering =
70-
VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
71-
arith::AttrConvertFastMathToLLVM>;
97+
DenormalOpConversionToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
98+
arith::AttrConvertFastMathToLLVM>;
7299
using DivSIOpLowering =
73100
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
74101
using DivUIOpLowering =
@@ -83,38 +110,38 @@ using FPToSIOpLowering =
83110
using FPToUIOpLowering =
84111
VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
85112
using MaximumFOpLowering =
86-
VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
87-
arith::AttrConvertFastMathToLLVM>;
113+
DenormalOpConversionToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
114+
arith::AttrConvertFastMathToLLVM>;
88115
using MaxNumFOpLowering =
89-
VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
90-
arith::AttrConvertFastMathToLLVM>;
116+
DenormalOpConversionToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
117+
arith::AttrConvertFastMathToLLVM>;
91118
using MaxSIOpLowering =
92119
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
93120
using MaxUIOpLowering =
94121
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
95122
using MinimumFOpLowering =
96-
VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
97-
arith::AttrConvertFastMathToLLVM>;
123+
DenormalOpConversionToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
124+
arith::AttrConvertFastMathToLLVM>;
98125
using MinNumFOpLowering =
99-
VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
100-
arith::AttrConvertFastMathToLLVM>;
126+
DenormalOpConversionToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
127+
arith::AttrConvertFastMathToLLVM>;
101128
using MinSIOpLowering =
102129
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
103130
using MinUIOpLowering =
104131
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
105132
using MulFOpLowering =
106-
VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
107-
arith::AttrConvertFastMathToLLVM>;
133+
DenormalOpConversionToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
134+
arith::AttrConvertFastMathToLLVM>;
108135
using MulIOpLowering =
109136
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
110137
arith::AttrConvertOverflowToLLVM>;
111138
using NegFOpLowering =
112-
VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
113-
arith::AttrConvertFastMathToLLVM>;
139+
DenormalOpConversionToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
140+
arith::AttrConvertFastMathToLLVM>;
114141
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
115142
using RemFOpLowering =
116-
VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
117-
arith::AttrConvertFastMathToLLVM>;
143+
DenormalOpConversionToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
144+
arith::AttrConvertFastMathToLLVM>;
118145
using RemSIOpLowering =
119146
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
120147
using RemUIOpLowering =
@@ -131,8 +158,8 @@ using ShRUIOpLowering =
131158
using SIToFPOpLowering =
132159
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
133160
using SubFOpLowering =
134-
VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
135-
arith::AttrConvertFastMathToLLVM>;
161+
DenormalOpConversionToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
162+
arith::AttrConvertFastMathToLLVM>;
136163
using SubIOpLowering =
137164
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
138165
arith::AttrConvertOverflowToLLVM>;

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -422,21 +422,23 @@ def TruncIShrUIMulIToMulUIExtended :
422422
//===----------------------------------------------------------------------===//
423423

424424
// mulf(negf(x), negf(y)) -> mulf(x,y)
425-
// (retain fastmath flags of original mulf)
425+
// (retain fastmath flags and denormal mode of the original divf)
426426
def MulFOfNegF :
427-
Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
428-
(Arith_MulFOp $x, $y, $fmf),
427+
Pat<(Arith_MulFOp (Arith_NegFOp $x, $_, $_),
428+
(Arith_NegFOp $y, $_, $_), $fmf, $mode),
429+
(Arith_MulFOp $x, $y, $fmf, $mode),
429430
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
430431

431432
//===----------------------------------------------------------------------===//
432433
// DivFOp
433434
//===----------------------------------------------------------------------===//
434435

435436
// divf(negf(x), negf(y)) -> divf(x,y)
436-
// (retain fastmath flags of original divf)
437+
// (retain fastmath flags and denormal mode of the original divf)
437438
def DivFOfNegF :
438-
Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
439-
(Arith_DivFOp $x, $y, $fmf),
439+
Pat<(Arith_DivFOp (Arith_NegFOp $x, $_, $_),
440+
(Arith_NegFOp $y, $_, $_), $fmf, $mode),
441+
(Arith_DivFOp $x, $y, $fmf, $mode),
440442
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
441443

442444
#endif // ARITH_PATTERNS

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -952,7 +952,7 @@ void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
952952
//===----------------------------------------------------------------------===//
953953

954954
OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
955-
/// negf(negf(x)) -> x
955+
// negf(negf(x)) -> x
956956
if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
957957
return op.getOperand();
958958
return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
@@ -982,6 +982,14 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
982982
if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
983983
return getLhs();
984984

985+
// Simplifies subf(x, rhs) to x if the following conditions are met:
986+
// 1. `rhs` is a denormal floating-point value.
987+
// 2. The denormal mode for the operation is set to positive zero.
988+
bool isPositiveZeroMode =
989+
getDenormalModeAttr().getValue() == DenormalMode::positive_zero;
990+
if (isPositiveZeroMode && matchPattern(adaptor.getRhs(), m_isDenormalFloat()))
991+
return getLhs();
992+
985993
return constFoldBinaryOp<FloatAttr>(
986994
adaptor.getOperands(),
987995
[](const APFloat &a, const APFloat &b) { return a - b; });

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,15 +1498,17 @@ static Operation *findPayloadOp(Block *body, bool initFirst = false) {
14981498

14991499
void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
15001500
SmallVector<StringRef> elidedAttrs;
1501-
std::string attrToElide;
15021501
p << " { " << payloadOp->getName().getStringRef();
15031502
for (const auto &attr : payloadOp->getAttrs()) {
1504-
auto fastAttr =
1505-
llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1506-
if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1507-
attrToElide = attr.getName().str();
1508-
elidedAttrs.push_back(attrToElide);
1509-
break;
1503+
if (auto fastAttr = dyn_cast<arith::FastMathFlagsAttr>(attr.getValue())) {
1504+
if (fastAttr.getValue() == arith::FastMathFlags::none) {
1505+
elidedAttrs.push_back(attr.getName());
1506+
}
1507+
}
1508+
if (auto denormAttr = dyn_cast<arith::DenormalModeAttr>(attr.getValue())) {
1509+
if (denormAttr.getValue() == arith::DenormalMode::ieee) {
1510+
elidedAttrs.push_back(attr.getName());
1511+
}
15101512
}
15111513
}
15121514
p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);

0 commit comments

Comments
 (0)