Skip to content

Commit 1e61b37

Browse files
authored
[mlir][vector] Tighten the semantics of vector.gather (llvm#135749)
This patch restricts `vector.gather` to only accept tensors and memrefs as valid sources. Currently, the source is typed as `AnyShaped`, which also includes vectors—allowing the following (invalid) construct to pass verification: ```mlir %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> ``` (Note: the source %base here is a vector, which is incorrect.) In contrast, `vector.scatter` currently only accepts memrefs, so some asymmetry remains between the two ops. This PR is a step toward aligning their semantics.
1 parent 41c97af commit 1e61b37

File tree

3 files changed

+36
-2
lines changed

3 files changed

+36
-2
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1972,7 +1972,7 @@ def Vector_GatherOp :
19721972
DeclareOpInterfaceMethods<MaskableOpInterface>,
19731973
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
19741974
]>,
1975-
Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
1975+
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
19761976
Variadic<Index>:$indices,
19771977
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
19781978
VectorOfNonZeroRankOf<[I1]>:$mask,

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;
6363
// Whether a type is a MemRefType.
6464
def IsMemRefTypePred : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">;
6565

66+
// Whether a type is a TensorType or a MemRefType.
67+
def IsTensorOrMemRefTypePred : Or<[IsTensorTypePred, IsMemRefTypePred]>;
68+
6669
// Whether a type is an UnrankedMemRefType
6770
def IsUnrankedMemRefTypePred
6871
: CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">;
@@ -426,7 +429,9 @@ class ValueSemanticsContainerOf<list<Type> allowedTypes> :
426429
ShapedContainerType<allowedTypes, HasValueSemanticsPred,
427430
"container with value semantics">;
428431

432+
//===----------------------------------------------------------------------===//
429433
// Vector types.
434+
//===----------------------------------------------------------------------===//
430435

431436
class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
432437
ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
@@ -755,7 +760,7 @@ class StaticShapeTensorOf<list<Type> allowedTypes>
755760
def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
756761

757762
//===----------------------------------------------------------------------===//
758-
// Memref type.
763+
// Memref types.
759764
//===----------------------------------------------------------------------===//
760765

761766
// Any unranked memref whose element type is from the given `allowedTypes` list.
@@ -878,6 +883,14 @@ class NestedTupleOf<list<Type> allowedTypes> :
878883
"getFlattenedTypes(::llvm::cast<::mlir::TupleType>($_self))",
879884
"nested tuple">;
880885

886+
//===----------------------------------------------------------------------===//
887+
// Mixed types
888+
//===----------------------------------------------------------------------===//
889+
890+
class TensorOrMemRef<list<Type> allowedTypes> :
891+
ShapedContainerType<allowedTypes, IsTensorOrMemRefTypePred, "Tensor or MemRef",
892+
"::mlir::ShapedType">;
893+
881894
//===----------------------------------------------------------------------===//
882895
// Common type constraints
883896
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,6 +1409,16 @@ func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1
14091409

14101410
// -----
14111411

1412+
func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
1413+
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
1414+
%c0 = arith.constant 0 : index
1415+
// expected-error@+1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
1416+
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
1417+
: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1418+
}
1419+
1420+
// -----
1421+
14121422
func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
14131423
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
14141424
%c0 = arith.constant 0 : index
@@ -1469,6 +1479,17 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
14691479

14701480
// -----
14711481

1482+
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
1483+
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
1484+
%c0 = arith.constant 0 : index
1485+
// expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
1486+
vector.scatter %base[%c0][%indices], %mask, %pass_thru
1487+
: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1488+
}
1489+
1490+
// -----
1491+
1492+
14721493
func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
14731494
%mask: vector<16xi1>, %value: vector<16xf32>) {
14741495
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)