Skip to content

Commit 067bd7d

Browse files
committed
[mlir][vector] Use optional for outerproduct accumulator instead of variadic
This was introduced before the Optional directive and uses Variadic, but it's really optional. Reviewed By: nicolasvasilache, benmxwl-arm, dcaballe Differential Revision: https://reviews.llvm.org/D159259
1 parent 8ba1c38 commit 067bd7d

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,7 @@ def Vector_OuterProductOp :
897897
TCresVTEtIsSameAsOpBase<0, 1>>,
898898
DeclareOpInterfaceMethods<MaskableOpInterface>]>,
899899
Arguments<(ins AnyVector:$lhs, AnyType:$rhs,
900-
Variadic<AnyVector>:$acc,
900+
Optional<AnyVector>:$acc,
901901
DefaultValuedAttr<Vector_CombiningKindAttr, "CombiningKind::ADD">:$kind)>,
902902
Results<(outs AnyVector)> {
903903
let summary = "vector outerproduct with optional fused add";
@@ -961,9 +961,9 @@ def Vector_OuterProductOp :
961961
return getRhs().getType();
962962
}
963963
VectorType getOperandVectorTypeACC() {
964-
return getAcc().empty()
965-
? VectorType()
966-
: ::llvm::cast<VectorType>((*getAcc().begin()).getType());
964+
return getAcc()
965+
? ::llvm::cast<VectorType>(getAcc().getType())
966+
: VectorType();
967967
}
968968
VectorType getResultVectorType() {
969969
return ::llvm::cast<VectorType>(getResult().getType());

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2756,7 +2756,7 @@ void OuterProductOp::build(OpBuilder &builder, OperationState &result,
27562756

27572757
void OuterProductOp::print(OpAsmPrinter &p) {
27582758
p << " " << getLhs() << ", " << getRhs();
2759-
if (!getAcc().empty()) {
2759+
if (getAcc()) {
27602760
p << ", " << getAcc();
27612761
p.printOptionalAttrDict((*this)->getAttrs());
27622762
}

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
11281128
VectorType resType = op.getResultVectorType();
11291129
Type eltType = resType.getElementType();
11301130
bool isInt = isa<IntegerType, IndexType>(eltType);
1131-
Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
1131+
Value acc = op.getAcc();
11321132
vector::CombiningKind kind = op.getKind();
11331133

11341134
// Vector mask setup.

0 commit comments

Comments
 (0)