Skip to content

Commit 7fe2294

Browse files
author
Jeff Niu
committed
[mlir][ods] Allow specifying return types of builders
This patch allows custom attribute and type builders to return something other than the C++ type of the attribute or type. This is useful for attributes or types that may perform extra work during construction (e.g. canonicalization) that could result in a different kind of attribute or type being returned. Reviewed By: rriddle, lattner Differential Revision: https://reviews.llvm.org/D129792
1 parent a7789d6 commit 7fe2294

File tree

11 files changed

+108
-33
lines changed

11 files changed

+108
-33
lines changed

Diff for: mlir/docs/AttributesAndTypes.md

+23
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ def MyType : ... {
347347
// its arguments.
348348
return Base::get(typeParam.getContext(), ...);
349349
}]>,
350+
TypeBuilder<(ins "int":$intParam), [{}], "IntegerType">,
350351
];
351352
}
352353
```
@@ -461,6 +462,28 @@ the builder used `TypeBuilderWithInferredContext` implies that the context
461462
parameter is not necessary as it can be inferred from the arguments to the
462463
builder.
463464
465+
The fifth builder will generate the declaration of a builder method with a
466+
custom return type, like:
467+
468+
```tablegen
469+
let builders = [
470+
TypeBuilder<(ins "int":$intParam), [{}], "IntegerType">,
471+
]
472+
```
473+
474+
```c++
475+
class MyType : /*...*/ {
476+
/*...*/
477+
static IntegerType get(::mlir::MLIRContext *context, int intParam);
478+
479+
};
480+
```
481+
482+
This generates a builder declaration the same as the first three examples, but
483+
the return type of the builder is user-specified instead of the attribute or
484+
type class. This is useful for defining builders of attributes and types that
485+
may fold or canonicalize on construction.
486+
464487
### Parsing and Printing
465488
466489
If a mnemonic was specified, the `hasCustomAssemblyFormat` and `assemblyFormat`

Diff for: mlir/include/mlir/IR/AttrTypeBase.td

+19-11
Original file line numberDiff line numberDiff line change
@@ -96,30 +96,38 @@ class PredTypeTrait<string descr, Pred pred> : PredTrait<descr, pred>;
9696
// This is necessary because the `body` is also used to generate `getChecked`
9797
// methods, which have a different underlying `Base::get*` call.
9898
//
99-
class AttrOrTypeBuilder<dag parameters, code bodyCode = ""> {
99+
class AttrOrTypeBuilder<dag parameters, code bodyCode = "",
100+
string returnTypeStr = ""> {
100101
dag dagParams = parameters;
101102
code body = bodyCode;
102103

104+
// Change the return type of the builder. By default, it is the type of the
105+
// attribute or type.
106+
string returnType = returnTypeStr;
107+
103108
// The context parameter can be inferred from one of the other parameters and
104109
// is not implicitly added to the parameter list.
105110
bit hasInferredContextParam = 0;
106111
}
107-
class AttrBuilder<dag parameters, code bodyCode = "">
108-
: AttrOrTypeBuilder<parameters, bodyCode>;
109-
class TypeBuilder<dag parameters, code bodyCode = "">
110-
: AttrOrTypeBuilder<parameters, bodyCode>;
112+
class AttrBuilder<dag parameters, code bodyCode = "", string returnType = "">
113+
: AttrOrTypeBuilder<parameters, bodyCode, returnType>;
114+
class TypeBuilder<dag parameters, code bodyCode = "", string returnType = "">
115+
: AttrOrTypeBuilder<parameters, bodyCode, returnType>;
111116

112117
// A class of AttrOrTypeBuilder that is able to infer the MLIRContext parameter
113118
// from one of the other builder parameters. Instances of this builder do not
114119
// have `MLIRContext *` implicitly added to the parameter list.
115-
class AttrOrTypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
116-
: TypeBuilder<parameters, bodyCode> {
120+
class AttrOrTypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
121+
string returnType = "">
122+
: TypeBuilder<parameters, bodyCode, returnType> {
117123
let hasInferredContextParam = 1;
118124
}
119-
class AttrBuilderWithInferredContext<dag parameters, code bodyCode = "">
120-
: AttrOrTypeBuilderWithInferredContext<parameters, bodyCode>;
121-
class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
122-
: AttrOrTypeBuilderWithInferredContext<parameters, bodyCode>;
125+
class AttrBuilderWithInferredContext<dag parameters, code bodyCode = "",
126+
string returnType = "">
127+
: AttrOrTypeBuilderWithInferredContext<parameters, bodyCode, returnType>;
128+
class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
129+
string returnType = "">
130+
: AttrOrTypeBuilderWithInferredContext<parameters, bodyCode, returnType>;
123131

124132
//===----------------------------------------------------------------------===//
125133
// Definitions

Diff for: mlir/include/mlir/IR/OpImplementation.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -792,14 +792,14 @@ class AsmParser {
792792
/// unlike `OpBuilder::getType`, this method does not implicitly insert a
793793
/// context parameter.
794794
template <typename T, typename... ParamsT>
795-
T getChecked(SMLoc loc, ParamsT &&...params) {
795+
auto getChecked(SMLoc loc, ParamsT &&...params) {
796796
return T::getChecked([&] { return emitError(loc); },
797797
std::forward<ParamsT>(params)...);
798798
}
799799
/// A variant of `getChecked` that uses the result of `getNameLoc` to emit
800800
/// errors.
801801
template <typename T, typename... ParamsT>
802-
T getChecked(ParamsT &&...params) {
802+
auto getChecked(ParamsT &&...params) {
803803
return T::getChecked([&] { return emitError(getNameLoc()); },
804804
std::forward<ParamsT>(params)...);
805805
}

Diff for: mlir/include/mlir/TableGen/AttrOrTypeDef.h

+3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ class AttrOrTypeBuilder : public Builder {
3737
public:
3838
using Builder::Builder;
3939

40+
/// Returns an optional builder return type.
41+
Optional<StringRef> getReturnType() const;
42+
4043
/// Returns true if this builder is able to infer the MLIRContext parameter.
4144
bool hasInferredContextParameter() const;
4245
};

Diff for: mlir/lib/TableGen/AttrOrTypeDef.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ using namespace mlir::tblgen;
2020
// AttrOrTypeBuilder
2121
//===----------------------------------------------------------------------===//
2222

23-
/// Returns true if this builder is able to infer the MLIRContext parameter.
23+
Optional<StringRef> AttrOrTypeBuilder::getReturnType() const {
24+
Optional<StringRef> type = def->getValueAsOptionalString("returnType");
25+
return type && !type->empty() ? type : llvm::None;
26+
}
27+
2428
bool AttrOrTypeBuilder::hasInferredContextParameter() const {
2529
return def->getValueAsBit("hasInferredContextParam");
2630
}

Diff for: mlir/test/lib/Dialect/Test/TestAttrDefs.td

+13
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,19 @@ def TestAttrSelfTypeParameterFormat
223223
let assemblyFormat = "`<` $a `>`";
224224
}
225225

226+
// Test overridding attribute builders with a custom builder.
227+
def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
228+
let mnemonic = "override_builder";
229+
let parameters = (ins "int":$a);
230+
let assemblyFormat = "`<` $a `>`";
231+
232+
let skipDefaultBuilders = 1;
233+
let genVerifyDecl = 1;
234+
let builders = [AttrBuilder<(ins "int":$a), [{
235+
return ::mlir::IntegerAttr::get(::mlir::IndexType::get($_ctxt), a);
236+
}], "::mlir::Attribute">];
237+
}
238+
226239
// Test simple extern 1D vector using ElementsAttrInterface.
227240
def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
228241
ElementsAttrInterface

Diff for: mlir/test/mlir-tblgen/attr-or-type-format.td

+14-14
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
5555
// ATTR: if (odsParser.parseRParen())
5656
// ATTR: return {};
5757
// ATTR: return TestAAttr::get(odsParser.getContext(),
58-
// ATTR: (*_result_value),
59-
// ATTR: (*_result_complex));
58+
// ATTR: IntegerAttr((*_result_value)),
59+
// ATTR: TestParamA((*_result_complex)));
6060
// ATTR: }
6161

6262
// ATTR: void TestAAttr::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -114,8 +114,8 @@ def AttrA : TestAttr<"TestA"> {
114114
// ATTR: return {};
115115
// ATTR: }
116116
// ATTR: return TestBAttr::get(odsParser.getContext(),
117-
// ATTR: (*_result_v0),
118-
// ATTR: (*_result_v1));
117+
// ATTR: TestParamA((*_result_v0)),
118+
// ATTR: TestParamB((*_result_v1)));
119119
// ATTR: }
120120

121121
// ATTR: void TestBAttr::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -151,8 +151,8 @@ def AttrB : TestAttr<"TestB"> {
151151
// ATTR: if (::mlir::failed(_result_v1))
152152
// ATTR: return {};
153153
// ATTR: return TestFAttr::get(odsParser.getContext(),
154-
// ATTR: (*_result_v0),
155-
// ATTR: (*_result_v1));
154+
// ATTR: int((*_result_v0)),
155+
// ATTR: int((*_result_v1)));
156156
// ATTR: }
157157

158158
def AttrC : TestAttr<"TestF"> {
@@ -278,10 +278,10 @@ def TypeA : TestType<"TestC"> {
278278
// TYPE: if (::mlir::failed(_result_v3))
279279
// TYPE: return {};
280280
// TYPE: return TestDType::get(odsParser.getContext(),
281-
// TYPE: (*_result_v0),
282-
// TYPE: (*_result_v1),
283-
// TYPE: (*_result_v2),
284-
// TYPE: (*_result_v3));
281+
// TYPE: TestParamC((*_result_v0)),
282+
// TYPE: TestParamD((*_result_v1)),
283+
// TYPE: TestParamC((*_result_v2)),
284+
// TYPE: TestParamD((*_result_v3)));
285285
// TYPE: }
286286

287287
// TYPE: void TestDType::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -369,10 +369,10 @@ def TypeB : TestType<"TestD"> {
369369
// TYPE: return {};
370370
// TYPE: }
371371
// TYPE: return TestEType::get(odsParser.getContext(),
372-
// TYPE: (*_result_v0),
373-
// TYPE: (*_result_v1),
374-
// TYPE: (*_result_v2),
375-
// TYPE: (*_result_v3));
372+
// TYPE: IntegerAttr((*_result_v0)),
373+
// TYPE: IntegerAttr((*_result_v1)),
374+
// TYPE: IntegerAttr((*_result_v2)),
375+
// TYPE: IntegerAttr((*_result_v3)));
376376
// TYPE: }
377377

378378
// TYPE: void TestEType::print(::mlir::AsmPrinter &odsPrinter) const {

Diff for: mlir/test/mlir-tblgen/attrdefs.td

+12-2
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ include "mlir/IR/OpBase.td"
3131
// DEF-NEXT: .Case(::test::IndexAttr::getMnemonic()
3232
// DEF-NEXT: value = ::test::IndexAttr::parse(parser, type);
3333
// DEF-NEXT: return ::mlir::success(!!value);
34-
// DEF: .Default([&](llvm::StringRef keyword,
34+
// DEF: .Default([&](llvm::StringRef keyword,
3535
// DEF-NEXT: *mnemonic = keyword;
36-
// DEF-NEXT: return llvm::None;
36+
// DEF-NEXT: return llvm::None;
3737

3838
def Test_Dialect: Dialect {
3939
// DECL-NOT: TestDialect
@@ -148,3 +148,13 @@ def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> {
148148
// DEF: ParamWithAccessorTypeAttrStorage
149149
// DEF: ParamWithAccessorTypeAttrStorage(std::string param)
150150
// DEF: StringRef ParamWithAccessorTypeAttr::getParam()
151+
152+
def G_BuilderWithReturnTypeAttr : TestAttr<"BuilderWithReturnType"> {
153+
let parameters = (ins "int":$a);
154+
let genVerifyDecl = 1;
155+
let builders = [AttrBuilder<(ins), [{ return {}; }], "::mlir::Attribute">];
156+
}
157+
158+
// DECL-LABEL: class BuilderWithReturnTypeAttr
159+
// DECL: ::mlir::Attribute get(
160+
// DECL: ::mlir::Attribute getChecked(

Diff for: mlir/test/mlir-tblgen/testdialect-attrdefs.mlir

+6
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,9 @@ func.func private @compoundA() attributes {foo = #test.cmpnd_a<1, !test.smpla, [
1313
// CHECK-LABEL: @qualifiedAttr()
1414
// CHECK-SAME: #test.cmpnd_nested_outer_qual<i #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>>
1515
func.func private @qualifiedAttr() attributes {foo = #test.cmpnd_nested_outer_qual<i #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>>}
16+
17+
// CHECK-LABEL: @overriddenAttr
18+
// CHECK-SAME: foo = 5 : index
19+
func.func private @overriddenAttr() attributes {
20+
foo = #test.override_builder<5>
21+
}

Diff for: mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,10 @@ getCustomBuilderParams(std::initializer_list<MethodParameter> prefix,
348348
void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
349349
// Don't emit a body if there isn't one.
350350
auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
351-
Method *m = defCls.addMethod(def.getCppClassName(), "get", props,
351+
StringRef returnType = def.getCppClassName();
352+
if (Optional<StringRef> builderReturnType = builder.getReturnType())
353+
returnType = *builderReturnType;
354+
Method *m = defCls.addMethod(returnType, "get", props,
352355
getCustomBuilderParams({}, builder));
353356
if (!builder.getBody())
354357
return;
@@ -373,8 +376,11 @@ static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
373376
void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
374377
// Don't emit a body if there isn't one.
375378
auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
379+
StringRef returnType = def.getCppClassName();
380+
if (Optional<StringRef> builderReturnType = builder.getReturnType())
381+
returnType = *builderReturnType;
376382
Method *m = defCls.addMethod(
377-
def.getCppClassName(), "getChecked", props,
383+
returnType, "getChecked", props,
378384
getCustomBuilderParams(
379385
{{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
380386
builder));

Diff for: mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,9 @@ void DefFormat::genParser(MethodBody &os) {
311311
} else {
312312
selfOs << formatv("(*_result_{0})", param.getName());
313313
}
314-
os << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()));
314+
os << param.getCppType() << "("
315+
<< tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()))
316+
<< ")";
315317
}
316318
os << ");";
317319
}

0 commit comments

Comments
 (0)