Skip to content

Commit 690dc4e

Browse files
quartersdgjpienaar
andauthored
Add AsmParser::parseDecimalInteger. (llvm#96255)
An attribute parser needs to parse lists of possibly negative integers separated by x in a way which is foiled by parseInteger handling hex formats and parseIntegerInDimensionList does not allow negatives. --------- Co-authored-by: Jacques Pienaar <[email protected]>
1 parent 785d376 commit 690dc4e

File tree

7 files changed

+149
-6
lines changed

7 files changed

+149
-6
lines changed

mlir/include/mlir/IR/OpImplementation.h

+30-5
Original file line numberDiff line numberDiff line change
@@ -714,16 +714,27 @@ class AsmParser {
714714
return *parseResult;
715715
}
716716

717+
/// Parse a decimal integer value from the stream.
718+
template <typename IntT>
719+
ParseResult parseDecimalInteger(IntT &result) {
720+
auto loc = getCurrentLocation();
721+
OptionalParseResult parseResult = parseOptionalDecimalInteger(result);
722+
if (!parseResult.has_value())
723+
return emitError(loc, "expected decimal integer value");
724+
return *parseResult;
725+
}
726+
717727
/// Parse an optional integer value from the stream.
718728
virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
729+
virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;
719730

720-
template <typename IntT>
721-
OptionalParseResult parseOptionalInteger(IntT &result) {
731+
private:
732+
template <typename IntT, typename ParseFn>
733+
OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
734+
ParseFn &&parseFn) {
722735
auto loc = getCurrentLocation();
723-
724-
// Parse the unsigned variant.
725736
APInt uintResult;
726-
OptionalParseResult parseResult = parseOptionalInteger(uintResult);
737+
OptionalParseResult parseResult = parseFn(uintResult);
727738
if (!parseResult.has_value() || failed(*parseResult))
728739
return parseResult;
729740

@@ -737,6 +748,20 @@ class AsmParser {
737748
return success();
738749
}
739750

751+
public:
752+
template <typename IntT>
753+
OptionalParseResult parseOptionalInteger(IntT &result) {
754+
return parseOptionalIntegerAndCheck(
755+
result, [&](APInt &result) { return parseOptionalInteger(result); });
756+
}
757+
758+
template <typename IntT>
759+
OptionalParseResult parseOptionalDecimalInteger(IntT &result) {
760+
return parseOptionalIntegerAndCheck(result, [&](APInt &result) {
761+
return parseOptionalDecimalInteger(result);
762+
});
763+
}
764+
740765
/// These are the supported delimiters around operand lists and region
741766
/// argument lists, used by parseOperandList.
742767
enum class Delimiter {

mlir/lib/AsmParser/AsmParserImpl.h

+5
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,11 @@ class AsmParserImpl : public BaseT {
322322
return parser.parseOptionalInteger(result);
323323
}
324324

325+
/// Parse an optional integer value from the stream.
326+
OptionalParseResult parseOptionalDecimalInteger(APInt &result) override {
327+
return parser.parseOptionalDecimalInteger(result);
328+
}
329+
325330
/// Parse a list of comma-separated items with an optional delimiter. If a
326331
/// delimiter is provided, then an empty list is allowed. If not, then at
327332
/// least one element will be parsed.

mlir/lib/AsmParser/Parser.cpp

+40
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "llvm/ADT/STLExtras.h"
4242
#include "llvm/ADT/ScopeExit.h"
4343
#include "llvm/ADT/Sequence.h"
44+
#include "llvm/ADT/StringExtras.h"
4445
#include "llvm/ADT/StringMap.h"
4546
#include "llvm/ADT/StringSet.h"
4647
#include "llvm/Support/Alignment.h"
@@ -307,6 +308,45 @@ OptionalParseResult Parser::parseOptionalInteger(APInt &result) {
307308
return success();
308309
}
309310

311+
/// Parse an optional integer value only in decimal format from the stream.
312+
OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
313+
Token curToken = getToken();
314+
if (curToken.isNot(Token::integer, Token::minus)) {
315+
return std::nullopt;
316+
}
317+
318+
bool negative = consumeIf(Token::minus);
319+
Token curTok = getToken();
320+
if (parseToken(Token::integer, "expected integer value")) {
321+
return failure();
322+
}
323+
324+
StringRef spelling = curTok.getSpelling();
325+
// If the integer is in hexadecimal return only the 0. The lexer has already
326+
// moved past the entire hexidecimal encoded integer so we reset the lex
327+
// pointer to just past the 0 we actualy want to consume.
328+
if (spelling[0] == '0' && spelling.size() > 1 &&
329+
llvm::toLower(spelling[1]) == 'x') {
330+
result = 0;
331+
state.lex.resetPointer(spelling.data() + 1);
332+
consumeToken();
333+
return success();
334+
}
335+
336+
if (spelling.getAsInteger(10, result))
337+
return emitError(curTok.getLoc(), "integer value too large");
338+
339+
// Make sure we have a zero at the top so we return the right signedness.
340+
if (result.isNegative())
341+
result = result.zext(result.getBitWidth() + 1);
342+
343+
// Process the negative sign if present.
344+
if (negative)
345+
result.negate();
346+
347+
return success();
348+
}
349+
310350
/// Parse a floating point value from an integer literal token.
311351
ParseResult Parser::parseFloatFromIntegerLiteral(
312352
std::optional<APFloat> &result, const Token &tok, bool isNegative,

mlir/lib/AsmParser/Parser.h

+3
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ class Parser {
144144
/// Parse an optional integer value from the stream.
145145
OptionalParseResult parseOptionalInteger(APInt &result);
146146

147+
/// Parse an optional integer value only in decimal format from the stream.
148+
OptionalParseResult parseOptionalDecimalInteger(APInt &result);
149+
147150
/// Parse a floating point value from an integer literal token.
148151
ParseResult parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
149152
const Token &tok, bool isNegative,

mlir/test/lib/Dialect/Test/TestAttrDefs.td

+9
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,15 @@ def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
8181
let mnemonic = "attr_with_trait";
8282
}
8383

84+
// An attribute of a list of decimal formatted integers in similar format to shapes.
85+
def TestDecimalShapeAttr : Test_Attr<"TestDecimalShape"> {
86+
let mnemonic = "decimal_shape";
87+
88+
let parameters = (ins ArrayRefParameter<"int64_t">:$shape);
89+
90+
let hasCustomAssemblyFormat = 1;
91+
}
92+
8493
// Test support for ElementsAttrInterface.
8594
def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [ElementsAttrInterface]> {
8695
let mnemonic = "i64_elements";

mlir/test/lib/Dialect/Test/TestAttributes.cpp

+36
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313

1414
#include "TestAttributes.h"
1515
#include "TestDialect.h"
16+
#include "TestTypes.h"
17+
#include "mlir/IR/Attributes.h"
1618
#include "mlir/IR/Builders.h"
1719
#include "mlir/IR/DialectImplementation.h"
1820
#include "mlir/IR/ExtensibleDialect.h"
21+
#include "mlir/IR/OpImplementation.h"
1922
#include "mlir/IR/Types.h"
2023
#include "llvm/ADT/APFloat.h"
2124
#include "llvm/ADT/Hashing.h"
@@ -63,6 +66,39 @@ void CompoundAAttr::print(AsmPrinter &printer) const {
6366
// CompoundAAttr
6467
//===----------------------------------------------------------------------===//
6568

69+
Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) {
70+
if (parser.parseLess()){
71+
return Attribute();
72+
}
73+
SmallVector<int64_t> shape;
74+
if (parser.parseOptionalGreater()) {
75+
auto parseDecimal = [&]() {
76+
shape.emplace_back();
77+
auto parseResult = parser.parseOptionalDecimalInteger(shape.back());
78+
if (!parseResult.has_value() || failed(*parseResult)) {
79+
parser.emitError(parser.getCurrentLocation()) << "expected an integer";
80+
return failure();
81+
}
82+
return success();
83+
};
84+
if (failed(parseDecimal())) {
85+
return Attribute();
86+
}
87+
while (failed(parser.parseOptionalGreater())) {
88+
if (failed(parser.parseXInDimensionList()) || failed(parseDecimal())) {
89+
return Attribute();
90+
}
91+
}
92+
}
93+
return get(parser.getContext(), shape);
94+
}
95+
96+
void TestDecimalShapeAttr::print(AsmPrinter &printer) const {
97+
printer << "<";
98+
llvm::interleave(getShape(), printer, "x");
99+
printer << ">";
100+
}
101+
66102
Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
67103
SmallVector<uint64_t> elements;
68104
if (parser.parseLess() || parser.parseLSquare())

mlir/test/mlir-tblgen/testdialect-attrdefs.mlir

+26-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
22

33
// CHECK-LABEL: func private @compoundA()
44
// CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]>
@@ -19,3 +19,28 @@ func.func private @qualifiedAttr() attributes {foo = #test.cmpnd_nested_outer_qu
1919
func.func private @overriddenAttr() attributes {
2020
foo = #test.override_builder<5>
2121
}
22+
23+
// CHECK-LABEL: @decimalIntegerShapeEmpty
24+
// CHECK-SAME: foo = #test.decimal_shape<>
25+
func.func private @decimalIntegerShapeEmpty() attributes {
26+
foo = #test.decimal_shape<>
27+
}
28+
29+
// CHECK-LABEL: @decimalIntegerShape
30+
// CHECK-SAME: foo = #test.decimal_shape<5>
31+
func.func private @decimalIntegerShape() attributes {
32+
foo = #test.decimal_shape<5>
33+
}
34+
35+
// CHECK-LABEL: @decimalIntegerShapeMultiple
36+
// CHECK-SAME: foo = #test.decimal_shape<0x3x7>
37+
func.func private @decimalIntegerShapeMultiple() attributes {
38+
foo = #test.decimal_shape<0x3x7>
39+
}
40+
41+
// -----
42+
43+
func.func private @hexdecimalInteger() attributes {
44+
// expected-error @below {{expected an integer}}
45+
sdg = #test.decimal_shape<1x0xb>
46+
}

0 commit comments

Comments
 (0)