Skip to content

Commit f596394

Browse files
committed
Add arm_neon.sdot operation
Differential Revision: https://reviews.llvm.org/D98198
1 parent cfc256b commit f596394

File tree

3 files changed

+67
-10
lines changed

3 files changed

+67
-10
lines changed

mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td

+35-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class ArmNeon_IntrOp<string mnemonic, list<int> overloadedResults,
3939
list<int> overloadedOperands, int numResults,
4040
list<OpTrait> traits = [], bit requiresAccessGroup = 0>
4141
: LLVM_IntrOpBase</*dialect=*/ArmNeon_Dialect,
42-
/*opName=*/mnemonic,
42+
/*opName=*/"intr." # mnemonic,
4343
/*enumName=*/"aarch64_neon_" # !subst(".", "_", mnemonic),
4444
/*overloadedResults=*/overloadedResults,
4545
/*overloadedOperands=*/overloadedOperands,
@@ -53,6 +53,13 @@ class ArmNeon_OverloadedOneResultIntrOp<string mnemonic,
5353
list<OpTrait> traits = []>
5454
: ArmNeon_IntrOp<mnemonic, [0], [], 1, traits>;
5555

56+
// ArmNeon dialect op that corresponds to an LLVM IR intrinsic with exactly one
57+
// result. The intrinsic is overloaded on that result (index 0) and on the
// operands whose positions are given in `overloadedOperands`; both lists and
// `traits` are forwarded unchanged to ArmNeon_IntrOp.
58+
class ArmNeon_OverloadedOperandsWithOneResultIntrOp<string mnemonic,
59+
list<int> overloadedOperands,
60+
list<OpTrait> traits = []>
61+
: ArmNeon_IntrOp<mnemonic, [0], overloadedOperands, 1, traits>;
62+
5663
def SMullOp : ArmNeon_OverloadedOneResultIntrOp<"smull", [
5764
NoSideEffect,
5865
AllTypesMatch<["a", "b"]>,
@@ -82,5 +89,32 @@ def SMullOp : ArmNeon_OverloadedOneResultIntrOp<"smull", [
8289
"$a `,` $b attr-dict `:` type($a) `to` type($res)";
8390
}
8491

92+
// Signed dot-product-accumulate op: res = dot(b, c) + a, with the i8 operands
// `b` and `c` consumed in groups of four elements per i32 lane of `a`/`res`.
// The intrinsic is overloaded on the result and on operand 1 (`b`).
def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [
93+
NoSideEffect,
94+
AllTypesMatch<["b", "c"]>,
95+
AllTypesMatch<["a", "res"]>,
96+
// The type builder below derives the result type from `b` as a 1-D i32
// vector with a quarter of `b`'s element count (getShape()[0] / 4).
TypesMatchWith<"res has the same number of elements as operand b",
97+
"b", "res",
98+
"VectorType::get({$_self.cast<VectorType>().getShape()[0] / 4},"
99+
"IntegerType::get($_self.getContext(), 32))">]> {
100+
let summary = "sdot op";
101+
let description = [{
102+
Signed integer addition of dot product (vector). This instruction performs
103+
the following operation on signed integer vectors: res = dot(b, c) + a,
104+
where vector operands are partitioned into groups of four elements.
105+
106+
Source:
107+
https://developer.arm.com/architectures/instruction-sets/simd-isas/neon/intrinsics
108+
}];
109+
// Supports either:
110+
// (vector<2xi32>, vector<8xi8>, vector<8xi8>) -> vector<2xi32>
111+
// (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> vector<4xi32>
112+
let arguments = (ins VectorOfLengthAndType<[4, 2], [I32]>:$a,
113+
VectorOfLengthAndType<[16, 8], [I8]>:$b,
114+
VectorOfLengthAndType<[16, 8], [I8]>:$c);
115+
let results = (outs VectorOfLengthAndType<[4, 2], [I32]>:$res);
116+
// Only the types of `b`, `c` and `res` are spelled in the custom syntax; the
// type of `a` is not printed and is tied to `res` by AllTypesMatch above.
let assemblyFormat =
117+
"$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
118+
}
85119

86120
#endif // ARMNEON_OPS

mlir/test/Dialect/ArmNeon/roundtrip.mlir

+13-6
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,25 @@
33
// CHECK-LABEL: arm_neon_smull
44
func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
55
-> (vector<8xi16>, vector<4xi32>, vector<2xi64>) {
6-
// CHECK: arm_neon.smull {{.*}}: vector<8xi8> to vector<8xi16>
7-
%0 = arm_neon.smull %a, %b : vector<8xi8> to vector<8xi16>
6+
// CHECK: arm_neon.intr.smull {{.*}}: vector<8xi8> to vector<8xi16>
7+
%0 = arm_neon.intr.smull %a, %b : vector<8xi8> to vector<8xi16>
88
%00 = vector.extract_strided_slice %0 {offsets = [3], sizes = [4], strides = [1]}:
99
vector<8xi16> to vector<4xi16>
1010

11-
// CHECK: arm_neon.smull {{.*}}: vector<4xi16> to vector<4xi32>
12-
%1 = arm_neon.smull %00, %00 : vector<4xi16> to vector<4xi32>
11+
// CHECK: arm_neon.intr.smull {{.*}}: vector<4xi16> to vector<4xi32>
12+
%1 = arm_neon.intr.smull %00, %00 : vector<4xi16> to vector<4xi32>
1313
%11 = vector.extract_strided_slice %1 {offsets = [1], sizes = [2], strides = [1]}:
1414
vector<4xi32> to vector<2xi32>
1515

16-
// CHECK: arm_neon.smull {{.*}}: vector<2xi32> to vector<2xi64>
17-
%2 = arm_neon.smull %11, %11 : vector<2xi32> to vector<2xi64>
16+
// CHECK: arm_neon.intr.smull {{.*}}: vector<2xi32> to vector<2xi64>
17+
%2 = arm_neon.intr.smull %11, %11 : vector<2xi32> to vector<2xi64>
1818

1919
return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
2020
}
21+
22+
// Round-trip test for arm_neon.intr.sdot using vector<8xi8> inputs and a
// vector<2xi32> accumulator/result; the op is expected to print back with the
// same operand and result types.
// CHECK-LABEL: arm_neon_sdot
23+
func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
24+
// CHECK: arm_neon.intr.sdot {{.*}}: vector<8xi8>, vector<8xi8> to vector<2xi32>
25+
%0 = arm_neon.intr.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32>
26+
return %0 : vector<2xi32>
27+
}

mlir/test/Target/LLVMIR/arm-neon.mlir

+19-3
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44
llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)> {
55
// CHECK: %[[V0:.*]] = call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %{{.*}}, <8 x i8> %{{.*}})
66
// CHECK-NEXT: %[[V00:.*]] = shufflevector <8 x i16> %3, <8 x i16> %[[V0]], <4 x i32> <i32 3, i32 4, i32 5, i32 6>
7-
%0 = arm_neon.smull %arg0, %arg1 : vector<8xi8> to vector<8xi16>
7+
%0 = arm_neon.intr.smull %arg0, %arg1 : vector<8xi8> to vector<8xi16>
88
%1 = llvm.shufflevector %0, %0 [3, 4, 5, 6] : vector<8xi16>, vector<8xi16>
99

1010
// CHECK-NEXT: %[[V1:.*]] = call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %[[V00]], <4 x i16> %[[V00]])
1111
// CHECK-NEXT: %[[V11:.*]] = shufflevector <4 x i32> %[[V1]], <4 x i32> %[[V1]], <2 x i32> <i32 1, i32 2>
12-
%2 = arm_neon.smull %1, %1 : vector<4xi16> to vector<4xi32>
12+
%2 = arm_neon.intr.smull %1, %1 : vector<4xi16> to vector<4xi32>
1313
%3 = llvm.shufflevector %2, %2 [1, 2] : vector<4xi32>, vector<4xi32>
1414

1515
// CHECK-NEXT: %[[V1:.*]] = call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %[[V11]], <2 x i32> %[[V11]])
16-
%4 = arm_neon.smull %3, %3 : vector<2xi32> to vector<2xi64>
16+
%4 = arm_neon.intr.smull %3, %3 : vector<2xi32> to vector<2xi64>
1717

1818
%5 = llvm.mlir.undef : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)>
1919
%6 = llvm.insertvalue %0, %5[0] : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)>
@@ -23,3 +23,19 @@ llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.str
2323
// CHECK: ret { <8 x i16>, <4 x i32>, <2 x i64> }
2424
llvm.return %8 : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)>
2525
}
26+
27+
// Translation test for the vector<8xi8> form of sdot: the op is expected to
// become a call to @llvm.aarch64.neon.sdot.v2i32.v8i8 whose value is returned.
// CHECK-LABEL: arm_neon_sdot_i8i8
28+
llvm.func @arm_neon_sdot_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
29+
// CHECK: %[[V0:.*]] = call <2 x i32> @llvm.aarch64.neon.sdot.v2i32.v8i8(<2 x i32> %{{.*}}, <8 x i8> %{{.*}}, <8 x i8> %{{.*}})
30+
// CHECK-NEXT: ret <2 x i32>
31+
%0 = arm_neon.intr.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32>
32+
llvm.return %0 : vector<2xi32>
33+
}
34+
35+
// Translation test for the vector<16xi8> form of sdot: the op is expected to
// become a call to @llvm.aarch64.neon.sdot.v4i32.v16i8 whose value is returned.
// NOTE(review): despite the `_i16i16` suffix, the inputs are 16-element i8
// vectors, not i16 vectors.
// CHECK-LABEL: arm_neon_sdot_i16i16
36+
llvm.func @arm_neon_sdot_i16i16(%a: vector<4xi32>, %b: vector<16xi8>, %c: vector<16xi8>) -> vector<4xi32> {
37+
// CHECK: %[[V0:.*]] = call <4 x i32> @llvm.aarch64.neon.sdot.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}})
38+
// CHECK-NEXT: ret <4 x i32>
39+
%0 = arm_neon.intr.sdot %a, %b, %c : vector<16xi8>, vector<16xi8> to vector<4xi32>
40+
llvm.return %0 : vector<4xi32>
41+
}

0 commit comments

Comments
 (0)