Skip to content

Commit c254b0b

Browse files
committed
[MLIR] Introduce std.global_memref and std.get_global_memref operations.
- Add standard dialect operations to define global variables with memref types and to retrieve the memref for to a named global variable - Extend unit tests to test verification for these operations. Differential Revision: https://reviews.llvm.org/D90337
1 parent 934b27a commit c254b0b

File tree

8 files changed

+336
-15
lines changed

8 files changed

+336
-15
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,6 +2005,97 @@ def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
20052005
let hasFolder = 0;
20062006
}
20072007

2008+
//===----------------------------------------------------------------------===//
2009+
// GlobalMemrefOp
2010+
//===----------------------------------------------------------------------===//
2011+
2012+
def GlobalMemrefOp : Std_Op<"global_memref", [NoSideEffect, Symbol]> {
2013+
let summary = "declare or define a global memref variable";
2014+
let description = [{
2015+
The `global_memref` operation declares or defines a named global variable.
2016+
The backing memory for the variable is allocated statically and is described
2017+
by the type of the variable (which should be a statically shaped memref
2018+
type). The operation is a declaration if no `inital_value` is specified,
2019+
else it is a definition. The `initial_value` can either be a unit attribute
2020+
to represent a definition of an uninitialized global variable, or an
2021+
elements attribute to represent the definition of a global variable with an
2022+
initial value. The global variable can also be marked constant using the
2023+
`constant` unit attribute. Writing to such constant global variables is
2024+
undefined.
2025+
2026+
The global variable can be accessed by using the `get_global_memref` to
2027+
retrieve the memref for the global variable. Note that the memref
2028+
for such global variable itself is immutable (i.e., get_global_memref for a
2029+
given global variable will always return the same memref descriptor).
2030+
2031+
Example:
2032+
2033+
```mlir
2034+
// Private variable with an initial value.
2035+
global_memref @x : memref<2xf32> { sym_visibility = "private",
2036+
initial_value = dense<0.0,2.0> : tensor<2xf32> }
2037+
2038+
// External variable.
2039+
global_memref @y : memref<4xi32> { sym_visibility = "public" }
2040+
2041+
// Uninitialized externally visible variable.
2042+
global_memref @z : memref<3xf16> { sym_visibility = "public",
2043+
initial_value }
2044+
```
2045+
}];
2046+
2047+
let arguments = (ins
2048+
SymbolNameAttr:$sym_name,
2049+
OptionalAttr<StrAttr>:$sym_visibility,
2050+
TypeAttr:$type,
2051+
OptionalAttr<AnyAttr>:$initial_value,
2052+
UnitAttr:$constant
2053+
);
2054+
2055+
let assemblyFormat = [{
2056+
($sym_visibility^)?
2057+
(`constant` $constant^)?
2058+
$sym_name `:`
2059+
custom<GlobalMemrefOpTypeAndInitialValue>($type, $initial_value)
2060+
attr-dict
2061+
}];
2062+
2063+
let extraClassDeclaration = [{
2064+
bool isExternal() { return !initial_value(); }
2065+
bool isUnitialized() {
2066+
return !isExternal() && initial_value().getValue().isa<UnitAttr>();
2067+
}
2068+
}];
2069+
}
2070+
2071+
//===----------------------------------------------------------------------===//
2072+
// GetGlobalMemrefOp
2073+
//===----------------------------------------------------------------------===//
2074+
2075+
def GetGlobalMemrefOp : Std_Op<"get_global_memref",
2076+
[NoSideEffect, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
2077+
let summary = "get the memref pointing to a global variable";
2078+
let description = [{
2079+
The `get_global_memref` operation retrieves the memref pointing to a
2080+
named global variable. If the global variable is marked constant, writing
2081+
to the result memref (such as through a `std.store` operation) is
2082+
undefined.
2083+
2084+
Example:
2085+
2086+
```mlir
2087+
%x = get_global_memref @foo : memref<2xf32>
2088+
```
2089+
}];
2090+
2091+
let arguments = (ins FlatSymbolRefAttr:$name);
2092+
let results = (outs AnyStaticShapeMemRef:$result);
2093+
let assemblyFormat = "$name `:` type($result) attr-dict";
2094+
2095+
// `GetGlobalMemrefOp` is fully verified by its traits.
2096+
let verifier = ?;
2097+
}
2098+
20082099
//===----------------------------------------------------------------------===//
20092100
// ImOp
20102101
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ class OpAsmParser {
395395

396396
// Parse any kind of attribute.
397397
Attribute attr;
398-
if (parseAttribute(attr))
398+
if (parseAttribute(attr, type))
399399
return failure();
400400

401401
// Check for the right kind of attribute.
@@ -436,6 +436,10 @@ class OpAsmParser {
436436
Type type,
437437
StringRef attrName,
438438
NamedAttrList &attrs) = 0;
439+
virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
440+
Type type,
441+
StringRef attrName,
442+
NamedAttrList &attrs) = 0;
439443

440444
/// Parse an arbitrary attribute of a given type and return it in result. This
441445
/// also adds the attribute to the specified attribute list with the specified

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 112 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,18 @@ static bool areVectorCastSimpleCompatible(
245245
return false;
246246
}
247247

248+
//===----------------------------------------------------------------------===//
249+
// Helpers for Tensor[Load|Store]Op, TensorToMemrefOp, and GlobalMemrefOp
250+
//===----------------------------------------------------------------------===//
251+
252+
static Type getTensorTypeFromMemRefType(Type type) {
253+
if (auto memref = type.dyn_cast<MemRefType>())
254+
return RankedTensorType::get(memref.getShape(), memref.getElementType());
255+
if (auto memref = type.dyn_cast<UnrankedMemRefType>())
256+
return UnrankedTensorType::get(memref.getElementType());
257+
return NoneType::get(type.getContext());
258+
}
259+
248260
//===----------------------------------------------------------------------===//
249261
// AddFOp
250262
//===----------------------------------------------------------------------===//
@@ -2140,6 +2152,106 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
21402152
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
21412153
}
21422154

2155+
//===----------------------------------------------------------------------===//
2156+
// GlobalMemrefOp
2157+
//===----------------------------------------------------------------------===//
2158+
2159+
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p,
2160+
GlobalMemrefOp op,
2161+
TypeAttr type,
2162+
Attribute initialValue) {
2163+
p << type;
2164+
if (!op.isExternal()) {
2165+
p << " = ";
2166+
if (op.isUnitialized())
2167+
p << "uninitialized";
2168+
else
2169+
p.printAttributeWithoutType(initialValue);
2170+
}
2171+
}
2172+
2173+
static ParseResult
2174+
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
2175+
Attribute &initialValue) {
2176+
Type type;
2177+
if (parser.parseType(type))
2178+
return failure();
2179+
2180+
auto memrefType = type.dyn_cast<MemRefType>();
2181+
if (!memrefType || !memrefType.hasStaticShape())
2182+
return parser.emitError(parser.getNameLoc())
2183+
<< "type should be static shaped memref, but got " << type;
2184+
typeAttr = TypeAttr::get(type);
2185+
2186+
if (parser.parseOptionalEqual())
2187+
return success();
2188+
2189+
if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
2190+
initialValue = UnitAttr::get(parser.getBuilder().getContext());
2191+
return success();
2192+
}
2193+
2194+
Type tensorType = getTensorTypeFromMemRefType(memrefType);
2195+
if (parser.parseAttribute(initialValue, tensorType))
2196+
return failure();
2197+
if (!initialValue.isa<ElementsAttr>())
2198+
return parser.emitError(parser.getNameLoc())
2199+
<< "initial value should be a unit or elements attribute";
2200+
return success();
2201+
}
2202+
2203+
static LogicalResult verify(GlobalMemrefOp op) {
2204+
auto memrefType = op.type().dyn_cast<MemRefType>();
2205+
if (!memrefType || !memrefType.hasStaticShape())
2206+
return op.emitOpError("type should be static shaped memref, but got ")
2207+
<< op.type();
2208+
2209+
// Verify that the initial value, if present, is either a unit attribute or
2210+
// an elements attribute.
2211+
if (op.initial_value().hasValue()) {
2212+
Attribute initValue = op.initial_value().getValue();
2213+
if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
2214+
return op.emitOpError("initial value should be a unit or elements "
2215+
"attribute, but got ")
2216+
<< initValue;
2217+
2218+
// Check that the type of the initial value is compatible with the type of
2219+
// the global variable.
2220+
if (initValue.isa<ElementsAttr>()) {
2221+
Type initType = initValue.getType();
2222+
Type tensorType = getTensorTypeFromMemRefType(memrefType);
2223+
if (initType != tensorType)
2224+
return op.emitOpError("initial value expected to be of type ")
2225+
<< tensorType << ", but was of type " << initType;
2226+
}
2227+
}
2228+
2229+
// TODO: verify visibility for declarations.
2230+
return success();
2231+
}
2232+
2233+
//===----------------------------------------------------------------------===//
2234+
// GetGlobalMemrefOp
2235+
//===----------------------------------------------------------------------===//
2236+
2237+
LogicalResult
2238+
GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2239+
// Verify that the result type is same as the type of the referenced
2240+
// global_memref op.
2241+
auto global =
2242+
symbolTable.lookupNearestSymbolFrom<GlobalMemrefOp>(*this, nameAttr());
2243+
if (!global)
2244+
return emitOpError("'")
2245+
<< name() << "' does not reference a valid global memref";
2246+
2247+
Type resultType = result().getType();
2248+
if (global.type() != resultType)
2249+
return emitOpError("result type ")
2250+
<< resultType << " does not match type " << global.type()
2251+
<< " of the global memref @" << name();
2252+
return success();
2253+
}
2254+
21432255
//===----------------------------------------------------------------------===//
21442256
// IndexCastOp
21452257
//===----------------------------------------------------------------------===//
@@ -3891,18 +4003,6 @@ void TensorCastOp::getCanonicalizationPatterns(
38914003
results.insert<ChainedTensorCast>(context);
38924004
}
38934005

3894-
//===----------------------------------------------------------------------===//
3895-
// Helpers for Tensor[Load|Store]Op and TensorToMemrefOp
3896-
//===----------------------------------------------------------------------===//
3897-
3898-
static Type getTensorTypeFromMemRefType(Type type) {
3899-
if (auto memref = type.dyn_cast<MemRefType>())
3900-
return RankedTensorType::get(memref.getShape(), memref.getElementType());
3901-
if (auto memref = type.dyn_cast<UnrankedMemRefType>())
3902-
return UnrankedTensorType::get(memref.getElementType());
3903-
return NoneType::get(type.getContext());
3904-
}
3905-
39064006
//===----------------------------------------------------------------------===//
39074007
// TensorLoadOp
39084008
//===----------------------------------------------------------------------===//

mlir/lib/Parser/AttributeParser.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,10 @@ OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
226226
Type type) {
227227
return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
228228
}
229+
OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
230+
Type type) {
231+
return parseOptionalAttributeWithToken(Token::string, attribute, type);
232+
}
229233

230234
/// Attribute dictionary.
231235
///
@@ -807,6 +811,7 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
807811

808812
/// Parse a dense elements attribute.
809813
Attribute Parser::parseDenseElementsAttr(Type attrType) {
814+
auto attribLoc = getToken().getLoc();
810815
consumeToken(Token::kw_dense);
811816
if (parseToken(Token::less, "expected '<' after 'dense'"))
812817
return nullptr;
@@ -819,11 +824,14 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
819824
return nullptr;
820825
}
821826

822-
auto typeLoc = getToken().getLoc();
827+
// If the type is specified `parseElementsLiteralType` will not parse a type.
828+
// Use the attribute location as the location for error reporting in that
829+
// case.
830+
auto loc = attrType ? attribLoc : getToken().getLoc();
823831
auto type = parseElementsLiteralType(attrType);
824832
if (!type)
825833
return nullptr;
826-
return literalParser.getAttr(typeLoc, type);
834+
return literalParser.getAttr(loc, type);
827835
}
828836

829837
/// Parse an opaque elements attribute.

mlir/lib/Parser/Parser.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,11 @@ class CustomOpAsmParser : public OpAsmParser {
10651065
NamedAttrList &attrs) override {
10661066
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
10671067
}
1068+
OptionalParseResult parseOptionalAttribute(StringAttr &result, Type type,
1069+
StringRef attrName,
1070+
NamedAttrList &attrs) override {
1071+
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
1072+
}
10681073

10691074
/// Parse a named dictionary into 'result' if it is present.
10701075
ParseResult parseOptionalAttrDict(NamedAttrList &result) override {

mlir/lib/Parser/Parser.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ class Parser {
188188
OptionalParseResult parseOptionalAttribute(Attribute &attribute,
189189
Type type = {});
190190
OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute, Type type);
191+
OptionalParseResult parseOptionalAttribute(StringAttr &attribute, Type type);
191192

192193
/// Parse an optional attribute that is demarcated by a specific token.
193194
template <typename AttributeT>

0 commit comments

Comments
 (0)