Skip to content

Commit a9cb529

Browse files
committed
[mlir][spirv] NFC: use Optional to replace SPV_Optional
Differential Revision: https://reviews.llvm.org/D78046
1 parent 359541e commit a9cb529

File tree

8 files changed

+12
-17
lines changed

8 files changed

+12
-17
lines changed

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3040,9 +3040,6 @@ def SPV_IntVec4 : SPV_Vec4<SPV_Integer>;
30403040
def SPV_IOrUIVec4 : SPV_Vec4<SPV_SignlessOrUnsignedInt>;
30413041
def SPV_Int32Vec4 : SPV_Vec4<AnyI32>;
30423042

3043-
// TODO(antiagainst): Use a more appropriate way to model optional operands
3044-
class SPV_Optional<Type type> : Variadic<type>;
3045-
30463043
// TODO(ravishankarm): From 1.4, this should also include Composite type.
30473044
def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>;
30483045

mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [
240240
);
241241

242242
let results = (outs
243-
SPV_Optional<SPV_Type>:$result
243+
Optional<SPV_Type>:$result
244244
);
245245

246246
let autogenSerialization = 0;

mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class SPV_GroupNonUniformArithmeticOp<string mnemonic, Type type,
3030
SPV_ScopeAttr:$execution_scope,
3131
SPV_GroupOperationAttr:$group_operation,
3232
SPV_ScalarOrVectorOf<type>:$value,
33-
SPV_Optional<SPV_Integer>:$cluster_size
33+
Optional<SPV_Integer>:$cluster_size
3434
);
3535

3636
let results = (outs

mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def SPV_VariableOp : SPV_Op<"Variable", []> {
469469

470470
let arguments = (ins
471471
SPV_StorageClassAttr:$storage_class,
472-
SPV_Optional<AnyType>:$initializer
472+
Optional<AnyType>:$initializer
473473
);
474474

475475
let results = (outs

mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
155155
groupOperation = rewriter.create<spirv::spvOp>( \
156156
loc, originalInputType.getElementType(), spirv::Scope::Subgroup, \
157157
spirv::GroupOperation::Reduce, inputElement, \
158-
/*cluster_size=*/ArrayRef<Value>()); \
158+
/*cluster_size=*/nullptr); \
159159
} break
160160
switch (*binaryOpKind) {
161161
CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp);

mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2291,6 +2291,10 @@ Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
22912291
<< operands[0];
22922292
}
22932293

2294+
// Use null type to mean no result type.
2295+
if (isVoidType(resultType))
2296+
resultType = nullptr;
2297+
22942298
auto resultID = operands[1];
22952299
auto functionID = operands[2];
22962300

@@ -2306,18 +2310,12 @@ Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
23062310
arguments.push_back(value);
23072311
}
23082312

2309-
SmallVector<Type, 1> resultTypes;
2310-
if (!isVoidType(resultType)) {
2311-
resultTypes.push_back(resultType);
2312-
}
2313-
23142313
auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
2315-
unknownLoc, resultTypes, opBuilder.getSymbolRefAttr(functionName),
2314+
unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName),
23162315
arguments);
23172316

2318-
if (!resultTypes.empty()) {
2317+
if (resultType)
23192318
valueMap[resultID] = opFunctionCall.getResult(0);
2320-
}
23212319
return success();
23222320
}
23232321

mlir/test/Dialect/SPIRV/control-flow-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ func @caller() {
202202

203203
spv.module Logical GLSL450 {
204204
spv.func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () "None" {
205-
// expected-error @+1 {{expected callee function to have 0 or 1 result, but provided 2}}
205+
// expected-error @+1 {{result group starting at #0 requires 0 or 1 element, but found 2}}
206206
%0:2 = spv.FunctionCall @f_invalid_result_type(%arg0, %arg1) : (i32, i32) -> (i32, i32)
207207
spv.Return
208208
}

mlir/utils/spirv/gen_spirv_dialect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ def map_spec_operand_to_ods_argument(operand):
548548
if quantifier == '':
549549
arg_type = 'SPV_Type'
550550
elif quantifier == '?':
551-
arg_type = 'SPV_Optional<SPV_Type>'
551+
arg_type = 'Optional<SPV_Type>'
552552
else:
553553
arg_type = 'Variadic<SPV_Type>'
554554
elif kind == 'IdMemorySemantics' or kind == 'IdScope':

0 commit comments

Comments
 (0)