Skip to content

Commit ec99d6e

Browse files
author
Denis Khalikov
committed
[mlir][spirv] Add a spirv::InterfaceVarABIAttr.
Summary: Add a proper dialect-specific attribute for interface variable ABI. Differential Revision: https://reviews.llvm.org/D77941
1 parent a9cb529 commit ec99d6e

File tree

12 files changed

+300
-82
lines changed

12 files changed

+300
-82
lines changed

mlir/docs/Dialects/SPIR-V.md

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -883,14 +883,30 @@ interfaces:
883883
* `spv.entry_point_abi` is a struct attribute that should be attached to the
884884
entry function. It contains:
885885
* `local_size` for specifying the local work group size for the dispatch.
886-
* `spv.interface_var_abi` is a struct attribute that should be attached to
887-
each operand and result of the entry function. It contains:
888-
* `descriptor_set` for specifying the descriptor set number for the
889-
corresponding resource variable.
890-
* `binding` for specifying the binding number for the corresponding
891-
resource variable.
892-
* `storage_class` for specifying the storage class for the corresponding
893-
resource variable.
886+
* `spv.interface_var_abi` is attribute that should be attached to each operand
887+
and result of the entry function. It should be of `#spv.interface_var_abi`
888+
attribute kind, which is defined as:
889+
890+
```
891+
spv-storage-class ::= `StorageBuffer` | ...
892+
spv-descriptor-set ::= integer-literal
893+
spv-binding ::= integer-literal
894+
spv-interface-var-abi ::= `#` `spv.interface_var_abi` `<(` spv-descriptor-set
895+
`,` spv-binding `)` (`,` spv-storage-class)? `>`
896+
```
897+
898+
For example,
899+
900+
```
901+
#spv.interface_var_abi<(0, 0), StorageBuffer>
902+
#spv.interface_var_abi<(0, 1)>
903+
```
904+
905+
The attribute has a few fields:
906+
907+
* Descriptor set number for the corresponding resource variable.
908+
* Binding number for the corresponding resource variable.
909+
* Storage class for the corresponding resource variable.
894910

895911
The SPIR-V dialect provides a [`LowerABIAttributesPass`][MlirSpirvPasses] for
896912
consuming these attributes and create SPIR-V module complying with the

mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_DIALECT_SPIRV_SPIRVATTRIBUTES_H
1414
#define MLIR_DIALECT_SPIRV_SPIRVATTRIBUTES_H
1515

16+
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
1617
#include "mlir/IR/Attributes.h"
1718
#include "mlir/Support/LLVM.h"
1819

@@ -26,18 +27,57 @@ enum class Extension;
2627
enum class Version : uint32_t;
2728

2829
namespace detail {
30+
struct InterfaceVarABIAttributeStorage;
2931
struct TargetEnvAttributeStorage;
3032
struct VerCapExtAttributeStorage;
3133
} // namespace detail
3234

3335
/// SPIR-V dialect-specific attribute kinds.
3436
namespace AttrKind {
3537
enum Kind {
36-
TargetEnv = Attribute::FIRST_SPIRV_ATTR, /// Target environment
38+
InterfaceVarABI = Attribute::FIRST_SPIRV_ATTR, /// Interface var ABI
39+
TargetEnv, /// Target environment
3740
VerCapExt, /// (version, extension, capability) triple
3841
};
3942
} // namespace AttrKind
4043

44+
/// An attribute that specifies the information regarding the interface
45+
/// variable: descriptor set, binding, storage class.
46+
class InterfaceVarABIAttr
47+
: public Attribute::AttrBase<InterfaceVarABIAttr, Attribute,
48+
detail::InterfaceVarABIAttributeStorage> {
49+
public:
50+
using Base::Base;
51+
52+
/// Gets a InterfaceVarABIAttr.
53+
static InterfaceVarABIAttr get(uint32_t descirptorSet, uint32_t binding,
54+
Optional<StorageClass> storageClass,
55+
MLIRContext *context);
56+
static InterfaceVarABIAttr get(IntegerAttr descriptorSet, IntegerAttr binding,
57+
IntegerAttr storageClass);
58+
59+
/// Returns the attribute kind's name (without the 'spv.' prefix).
60+
static StringRef getKindName();
61+
62+
/// Returns descriptor set.
63+
uint32_t getDescriptorSet();
64+
65+
/// Returns binding.
66+
uint32_t getBinding();
67+
68+
/// Returns `spirv::StorageClass`.
69+
Optional<StorageClass> getStorageClass();
70+
71+
static bool kindof(unsigned kind) {
72+
return kind == AttrKind::InterfaceVarABI;
73+
}
74+
75+
static LogicalResult verifyConstructionInvariants(Location loc,
76+
IntegerAttr descriptorSet,
77+
IntegerAttr binding,
78+
IntegerAttr storageClass);
79+
};
80+
4181
/// An attribute that specifies the SPIR-V (version, capabilities, extensions)
4282
/// triple.
4383
class VerCapExtAttr

mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,6 @@
2323

2424
include "mlir/Dialect/SPIRV/SPIRVBase.td"
2525

26-
// For arguments that eventually map to spv.globalVariable for the
27-
// shader interface, this attribute specifies the information regarding
28-
// the global variable:
29-
// 1) Descriptor Set.
30-
// 2) Binding number.
31-
// 3) Storage class.
32-
def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPIRV_Dialect, [
33-
StructFieldAttr<"descriptor_set", I32Attr>,
34-
StructFieldAttr<"binding", I32Attr>,
35-
StructFieldAttr<"storage_class", OptionalAttr<SPV_StorageClassAttr>>
36-
]>;
37-
3826
// For entry functions, this attribute specifies information related to entry
3927
// points in the generated SPIR-V module:
4028
// 1) WorkGroup Size.

mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,32 @@ namespace mlir {
2525

2626
namespace spirv {
2727
namespace detail {
28+
29+
struct InterfaceVarABIAttributeStorage : public AttributeStorage {
30+
using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
31+
32+
InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding,
33+
Attribute storageClass)
34+
: descriptorSet(descriptorSet), binding(binding),
35+
storageClass(storageClass) {}
36+
37+
bool operator==(const KeyTy &key) const {
38+
return std::get<0>(key) == descriptorSet && std::get<1>(key) == binding &&
39+
std::get<2>(key) == storageClass;
40+
}
41+
42+
static InterfaceVarABIAttributeStorage *
43+
construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
44+
return new (allocator.allocate<InterfaceVarABIAttributeStorage>())
45+
InterfaceVarABIAttributeStorage(std::get<0>(key), std::get<1>(key),
46+
std::get<2>(key));
47+
}
48+
49+
Attribute descriptorSet;
50+
Attribute binding;
51+
Attribute storageClass;
52+
};
53+
2854
struct VerCapExtAttributeStorage : public AttributeStorage {
2955
using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
3056

@@ -72,6 +98,74 @@ struct TargetEnvAttributeStorage : public AttributeStorage {
7298
} // namespace spirv
7399
} // namespace mlir
74100

101+
//===----------------------------------------------------------------------===//
102+
// InterfaceVarABIAttr
103+
//===----------------------------------------------------------------------===//
104+
105+
spirv::InterfaceVarABIAttr
106+
spirv::InterfaceVarABIAttr::get(uint32_t descriptorSet, uint32_t binding,
107+
Optional<spirv::StorageClass> storageClass,
108+
MLIRContext *context) {
109+
Builder b(context);
110+
auto descriptorSetAttr = b.getI32IntegerAttr(descriptorSet);
111+
auto bindingAttr = b.getI32IntegerAttr(binding);
112+
auto storageClassAttr =
113+
storageClass ? b.getI32IntegerAttr(static_cast<uint32_t>(*storageClass))
114+
: IntegerAttr();
115+
return get(descriptorSetAttr, bindingAttr, storageClassAttr);
116+
}
117+
118+
spirv::InterfaceVarABIAttr
119+
spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
120+
IntegerAttr storageClass) {
121+
assert(descriptorSet && binding);
122+
MLIRContext *context = descriptorSet.getContext();
123+
return Base::get(context, spirv::AttrKind::InterfaceVarABI, descriptorSet,
124+
binding, storageClass);
125+
}
126+
127+
StringRef spirv::InterfaceVarABIAttr::getKindName() {
128+
return "interface_var_abi";
129+
}
130+
131+
uint32_t spirv::InterfaceVarABIAttr::getBinding() {
132+
return getImpl()->binding.cast<IntegerAttr>().getInt();
133+
}
134+
135+
uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() {
136+
return getImpl()->descriptorSet.cast<IntegerAttr>().getInt();
137+
}
138+
139+
Optional<spirv::StorageClass> spirv::InterfaceVarABIAttr::getStorageClass() {
140+
if (getImpl()->storageClass)
141+
return static_cast<spirv::StorageClass>(
142+
getImpl()->storageClass.cast<IntegerAttr>().getValue().getZExtValue());
143+
return llvm::None;
144+
}
145+
146+
LogicalResult spirv::InterfaceVarABIAttr::verifyConstructionInvariants(
147+
Location loc, IntegerAttr descriptorSet, IntegerAttr binding,
148+
IntegerAttr storageClass) {
149+
if (!descriptorSet.getType().isSignlessInteger(32))
150+
return emitError(loc, "expected 32-bit integer for descriptor set");
151+
152+
if (!binding.getType().isSignlessInteger(32))
153+
return emitError(loc, "expected 32-bit integer for binding");
154+
155+
if (storageClass) {
156+
if (auto storageClassAttr = storageClass.cast<IntegerAttr>()) {
157+
auto storageClassValue =
158+
spirv::symbolizeStorageClass(storageClassAttr.getInt());
159+
if (!storageClassValue)
160+
return emitError(loc, "unknown storage class");
161+
} else {
162+
return emitError(loc, "expected valid storage class");
163+
}
164+
}
165+
166+
return success();
167+
}
168+
75169
//===----------------------------------------------------------------------===//
76170
// VerCapExtAttr
77171
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context)
118118
: Dialect(getDialectNamespace(), context) {
119119
addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
120120

121-
addAttributes<TargetEnvAttr, VerCapExtAttr>();
121+
addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
122122

123123
// Add SPIR-V ops.
124124
addOperations<
@@ -649,6 +649,75 @@ static ParseResult parseKeywordList(
649649
return success();
650650
}
651651

652+
/// Parses a spirv::InterfaceVarABIAttr.
653+
static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser) {
654+
if (parser.parseLess())
655+
return {};
656+
657+
Builder &builder = parser.getBuilder();
658+
659+
if (parser.parseLParen())
660+
return {};
661+
662+
IntegerAttr descriptorSetAttr;
663+
{
664+
auto loc = parser.getCurrentLocation();
665+
uint32_t descriptorSet = 0;
666+
auto descriptorSetParseResult = parser.parseOptionalInteger(descriptorSet);
667+
668+
if (!descriptorSetParseResult.hasValue() ||
669+
failed(*descriptorSetParseResult)) {
670+
parser.emitError(loc, "missing descriptor set");
671+
return {};
672+
}
673+
descriptorSetAttr = builder.getI32IntegerAttr(descriptorSet);
674+
}
675+
676+
if (parser.parseComma())
677+
return {};
678+
679+
IntegerAttr bindingAttr;
680+
{
681+
auto loc = parser.getCurrentLocation();
682+
uint32_t binding = 0;
683+
auto bindingParseResult = parser.parseOptionalInteger(binding);
684+
685+
if (!bindingParseResult.hasValue() || failed(*bindingParseResult)) {
686+
parser.emitError(loc, "missing binding");
687+
return {};
688+
}
689+
bindingAttr = builder.getI32IntegerAttr(binding);
690+
}
691+
692+
if (parser.parseRParen())
693+
return {};
694+
695+
IntegerAttr storageClassAttr;
696+
{
697+
if (succeeded(parser.parseOptionalComma())) {
698+
auto loc = parser.getCurrentLocation();
699+
StringRef storageClass;
700+
if (parser.parseKeyword(&storageClass))
701+
return {};
702+
703+
if (auto storageClassSymbol =
704+
spirv::symbolizeStorageClass(storageClass)) {
705+
storageClassAttr = builder.getI32IntegerAttr(
706+
static_cast<uint32_t>(*storageClassSymbol));
707+
} else {
708+
parser.emitError(loc, "unknown storage class: ") << storageClass;
709+
return {};
710+
}
711+
}
712+
}
713+
714+
if (parser.parseGreater())
715+
return {};
716+
717+
return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr,
718+
storageClassAttr);
719+
}
720+
652721
static Attribute parseVerCapExtAttr(DialectAsmParser &parser) {
653722
if (parser.parseLess())
654723
return {};
@@ -771,6 +840,8 @@ Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
771840
return parseTargetEnvAttr(parser);
772841
if (attrKind == spirv::VerCapExtAttr::getKindName())
773842
return parseVerCapExtAttr(parser);
843+
if (attrKind == spirv::InterfaceVarABIAttr::getKindName())
844+
return parseInterfaceVarABIAttr(parser);
774845

775846
parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: ")
776847
<< attrKind;
@@ -801,12 +872,25 @@ static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
801872
printer << ", " << targetEnv.getResourceLimits() << ">";
802873
}
803874

875+
static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr,
876+
DialectAsmPrinter &printer) {
877+
printer << spirv::InterfaceVarABIAttr::getKindName() << "<("
878+
<< interfaceVarABIAttr.getDescriptorSet() << ", "
879+
<< interfaceVarABIAttr.getBinding() << ")";
880+
auto storageClass = interfaceVarABIAttr.getStorageClass();
881+
if (storageClass)
882+
printer << ", " << spirv::stringifyStorageClass(*storageClass);
883+
printer << ">";
884+
}
885+
804886
void SPIRVDialect::printAttribute(Attribute attr,
805887
DialectAsmPrinter &printer) const {
806888
if (auto targetEnv = attr.dyn_cast<TargetEnvAttr>())
807889
print(targetEnv, printer);
808890
else if (auto vceAttr = attr.dyn_cast<VerCapExtAttr>())
809891
print(vceAttr, printer);
892+
else if (auto interfaceVarABIAttr = attr.dyn_cast<InterfaceVarABIAttr>())
893+
print(interfaceVarABIAttr, printer);
810894
else
811895
llvm_unreachable("unhandled SPIR-V attribute kind");
812896
}
@@ -866,11 +950,9 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
866950
auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>();
867951
if (!varABIAttr)
868952
return emitError(loc, "'")
869-
<< symbol
870-
<< "' attribute must be a dictionary attribute containing two or "
871-
"three 32-bit integer attributes: 'descriptor_set', 'binding', "
872-
"and optional 'storage_class'";
873-
if (varABIAttr.storage_class() && !valueType.isIntOrIndexOrFloat())
953+
<< symbol << "' must be a spirv::InterfaceVarABIAttr";
954+
955+
if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
874956
return emitError(loc, "'") << symbol
875957
<< "' attribute cannot specify storage class "
876958
"when attaching to a non-scalar value";

mlir/lib/Dialect/SPIRV/TargetAndABI.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,8 @@ spirv::InterfaceVarABIAttr
8686
spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
8787
Optional<spirv::StorageClass> storageClass,
8888
MLIRContext *context) {
89-
Type i32Type = IntegerType::get(32, context);
90-
auto scAttr =
91-
storageClass
92-
? IntegerAttr::get(i32Type, static_cast<int64_t>(*storageClass))
93-
: IntegerAttr();
94-
return spirv::InterfaceVarABIAttr::get(
95-
IntegerAttr::get(i32Type, descriptorSet),
96-
IntegerAttr::get(i32Type, binding), scAttr, context);
89+
return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass,
90+
context);
9791
}
9892

9993
StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; }

0 commit comments

Comments
 (0)