Skip to content

Commit ed8222b

Browse files
committed
[mlir] [VectorOps] Implement vector tuple get folding
Summary: Rewrites get-i tup<a1,...,an> into ai Reviewers: nicolasvasilache, rriddle, andydavis1 Reviewed By: nicolasvasilache, rriddle, andydavis1 Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D73213
1 parent 4ed7355 commit ed8222b

File tree

3 files changed

+22
-0
lines changed

3 files changed

+22
-0
lines changed

mlir/include/mlir/Dialect/VectorOps/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,7 @@ def Vector_TupleGetOp :
11151115
}
11161116
static StringRef getIndexAttrName() { return "index"; }
11171117
}];
1118+
let hasFolder = 1;
11181119
}
11191120

11201121
def Vector_PrintOp :

mlir/lib/Dialect/VectorOps/VectorOps.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,6 +1681,18 @@ static LogicalResult verify(TupleGetOp op) {
16811681
return success();
16821682
}
16831683

1684+
OpFoldResult TupleGetOp::fold(ArrayRef<Attribute> operands) {
1685+
// Rewrite:
1686+
// %t = vector.tuple .., %e_i, ..
1687+
// %x = vector.tuple_get %t, i
1688+
// into:
1689+
// %t = vector.tuple .., %e_i, .. // one less use
1690+
// %x = %e_i
1691+
if (auto tupleOp = dyn_cast_or_null<TupleOp>(getOperand().getDefiningOp()))
1692+
return tupleOp.getOperand(getIndex());
1693+
return {};
1694+
}
1695+
16841696
//===----------------------------------------------------------------------===//
16851697
// ConstantMaskOp
16861698
//===----------------------------------------------------------------------===//

mlir/test/Dialect/VectorOps/vector-transforms.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,12 @@ func @vector_transfers(%arg0: index, %arg1: index) {
302302
}
303303
return
304304
}
305+
306+
// CHECK-LABEL: func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>)
307+
// CHECK: return %arg1
308+
309+
func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
310+
%0 = vector.tuple %arg0, %arg1 : vector<4xf32>, vector<8xf32>
311+
%1 = vector.tuple_get %0, 1 : tuple<vector<4xf32>, vector<8xf32>>
312+
return %1 : vector<8xf32>
313+
}

0 commit comments

Comments
 (0)