Skip to content

Commit 4c83c27

Browse files
authored
[mlir][spirv] Add folding for [I|Logical][Not]Equal (llvm#74194)
1 parent cf048e1 commit 4c83c27

File tree

4 files changed

+256
-11
lines changed

4 files changed

+256
-11
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td

+8-1
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,8 @@ def SPIRV_IEqualOp : SPIRV_LogicalBinaryOp<"IEqual",
369369
%5 = spirv.IEqual %2, %3 : vector<4xi32>
370370
```
371371
}];
372+
373+
let hasFolder = 1;
372374
}
373375

374376
// -----
@@ -395,6 +397,8 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual",
395397

396398
```
397399
}];
400+
401+
let hasFolder = 1;
398402
}
399403

400404
// -----
@@ -501,6 +505,8 @@ def SPIRV_LogicalEqualOp : SPIRV_LogicalBinaryOp<"LogicalEqual",
501505
%2 = spirv.LogicalEqual %0, %1 : vector<4xi1>
502506
```
503507
}];
508+
509+
let hasFolder = 1;
504510
}
505511

506512
// -----
@@ -557,7 +563,8 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual",
557563
%2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
558564
```
559565
}];
560-
let hasFolder = true;
566+
567+
let hasFolder = 1;
561568
}
562569

563570
// -----

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

+75-2
Original file line numberDiff line numberDiff line change
@@ -662,19 +662,52 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
662662
return Attribute();
663663
}
664664

665+
//===----------------------------------------------------------------------===//
666+
// spirv.LogicalEqualOp
667+
//===----------------------------------------------------------------------===//
668+
669+
OpFoldResult
670+
spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
671+
// x == x -> true
672+
if (getOperand1() == getOperand2()) {
673+
auto trueAttr = BoolAttr::get(getContext(), true);
674+
if (isa<IntegerType>(getType()))
675+
return trueAttr;
676+
if (auto vecTy = dyn_cast<VectorType>(getType()))
677+
return SplatElementsAttr::get(vecTy, trueAttr);
678+
}
679+
680+
return constFoldBinaryOp<IntegerAttr>(
681+
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
682+
return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
683+
});
684+
}
685+
665686
//===----------------------------------------------------------------------===//
666687
// spirv.LogicalNotEqualOp
667688
//===----------------------------------------------------------------------===//
668689

669690
OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
670691
if (std::optional<bool> rhs =
671692
getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
672-
// x && false = x
693+
// x != false -> x
673694
if (!rhs.value())
674695
return getOperand1();
675696
}
676697

677-
return Attribute();
698+
// x == x -> false
699+
if (getOperand1() == getOperand2()) {
700+
auto falseAttr = BoolAttr::get(getContext(), false);
701+
if (isa<IntegerType>(getType()))
702+
return falseAttr;
703+
if (auto vecTy = dyn_cast<VectorType>(getType()))
704+
return SplatElementsAttr::get(vecTy, falseAttr);
705+
}
706+
707+
return constFoldBinaryOp<IntegerAttr>(
708+
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
709+
return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
710+
});
678711
}
679712

680713
//===----------------------------------------------------------------------===//
@@ -709,6 +742,46 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
709742
return Attribute();
710743
}
711744

745+
//===----------------------------------------------------------------------===//
746+
// spirv.IEqualOp
747+
//===----------------------------------------------------------------------===//
748+
749+
OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
750+
// x == x -> true
751+
if (getOperand1() == getOperand2()) {
752+
auto trueAttr = BoolAttr::get(getContext(), true);
753+
if (isa<IntegerType>(getType()))
754+
return trueAttr;
755+
if (auto vecTy = dyn_cast<VectorType>(getType()))
756+
return SplatElementsAttr::get(vecTy, trueAttr);
757+
}
758+
759+
return constFoldBinaryOp<IntegerAttr>(
760+
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
761+
return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
762+
});
763+
}
764+
765+
//===----------------------------------------------------------------------===//
766+
// spirv.INotEqualOp
767+
//===----------------------------------------------------------------------===//
768+
769+
OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
770+
// x == x -> false
771+
if (getOperand1() == getOperand2()) {
772+
auto falseAttr = BoolAttr::get(getContext(), false);
773+
if (isa<IntegerType>(getType()))
774+
return falseAttr;
775+
if (auto vecTy = dyn_cast<VectorType>(getType()))
776+
return SplatElementsAttr::get(vecTy, falseAttr);
777+
}
778+
779+
return constFoldBinaryOp<IntegerAttr>(
780+
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
781+
return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
782+
});
783+
}
784+
712785
//===----------------------------------------------------------------------===//
713786
// spirv.ShiftLeftLogical
714787
//===----------------------------------------------------------------------===//

mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir

+8-8
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
// CHECK-LABEL: @logical_equal_scalar
88
spirv.func @logical_equal_scalar(%arg0: i1, %arg1: i1) "None" {
99
// CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : i1
10-
%0 = spirv.LogicalEqual %arg0, %arg0 : i1
10+
%0 = spirv.LogicalEqual %arg0, %arg1 : i1
1111
spirv.Return
1212
}
1313

1414
// CHECK-LABEL: @logical_equal_vector
1515
spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
1616
// CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : vector<4xi1>
17-
%0 = spirv.LogicalEqual %arg0, %arg0 : vector<4xi1>
17+
%0 = spirv.LogicalEqual %arg0, %arg1 : vector<4xi1>
1818
spirv.Return
1919
}
2020

@@ -25,14 +25,14 @@ spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None
2525
// CHECK-LABEL: @logical_not_equal_scalar
2626
spirv.func @logical_not_equal_scalar(%arg0: i1, %arg1: i1) "None" {
2727
// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i1
28-
%0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
28+
%0 = spirv.LogicalNotEqual %arg0, %arg1 : i1
2929
spirv.Return
3030
}
3131

3232
// CHECK-LABEL: @logical_not_equal_vector
3333
spirv.func @logical_not_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
3434
// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : vector<4xi1>
35-
%0 = spirv.LogicalNotEqual %arg0, %arg0 : vector<4xi1>
35+
%0 = spirv.LogicalNotEqual %arg0, %arg1 : vector<4xi1>
3636
spirv.Return
3737
}
3838

@@ -63,14 +63,14 @@ spirv.func @logical_not_vector(%arg0: vector<4xi1>) "None" {
6363
// CHECK-LABEL: @logical_and_scalar
6464
spirv.func @logical_and_scalar(%arg0: i1, %arg1: i1) "None" {
6565
// CHECK: llvm.and %{{.*}}, %{{.*}} : i1
66-
%0 = spirv.LogicalAnd %arg0, %arg0 : i1
66+
%0 = spirv.LogicalAnd %arg0, %arg1 : i1
6767
spirv.Return
6868
}
6969

7070
// CHECK-LABEL: @logical_and_vector
7171
spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
7272
// CHECK: llvm.and %{{.*}}, %{{.*}} : vector<4xi1>
73-
%0 = spirv.LogicalAnd %arg0, %arg0 : vector<4xi1>
73+
%0 = spirv.LogicalAnd %arg0, %arg1 : vector<4xi1>
7474
spirv.Return
7575
}
7676

@@ -81,13 +81,13 @@ spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None"
8181
// CHECK-LABEL: @logical_or_scalar
8282
spirv.func @logical_or_scalar(%arg0: i1, %arg1: i1) "None" {
8383
// CHECK: llvm.or %{{.*}}, %{{.*}} : i1
84-
%0 = spirv.LogicalOr %arg0, %arg0 : i1
84+
%0 = spirv.LogicalOr %arg0, %arg1 : i1
8585
spirv.Return
8686
}
8787

8888
// CHECK-LABEL: @logical_or_vector
8989
spirv.func @logical_or_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
9090
// CHECK: llvm.or %{{.*}}, %{{.*}} : vector<4xi1>
91-
%0 = spirv.LogicalOr %arg0, %arg0 : vector<4xi1>
91+
%0 = spirv.LogicalOr %arg0, %arg1 : vector<4xi1>
9292
spirv.Return
9393
}

mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

+165
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,48 @@ func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<
10481048
spirv.ReturnValue %3 : vector<3xi1>
10491049
}
10501050

1051+
// -----
1052+
1053+
//===----------------------------------------------------------------------===//
1054+
// spirv.LogicalEqual
1055+
//===----------------------------------------------------------------------===//
1056+
1057+
// CHECK-LABEL: @logical_equal_same
1058+
func.func @logical_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
1059+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1060+
// CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
1061+
1062+
%0 = spirv.LogicalEqual %arg0, %arg0 : i1
1063+
%1 = spirv.LogicalEqual %arg1, %arg1 : vector<3xi1>
1064+
// CHECK: return %[[CTRUE]], %[[CVTRUE]]
1065+
return %0, %1 : i1, vector<3xi1>
1066+
}
1067+
1068+
// CHECK-LABEL: @const_fold_scalar_logical_equal
1069+
func.func @const_fold_scalar_logical_equal() -> (i1, i1) {
1070+
%true = spirv.Constant true
1071+
%false = spirv.Constant false
1072+
1073+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1074+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1075+
%0 = spirv.LogicalEqual %true, %false : i1
1076+
%1 = spirv.LogicalEqual %false, %false : i1
1077+
1078+
// CHECK: return %[[CFALSE]], %[[CTRUE]]
1079+
return %0, %1 : i1, i1
1080+
}
1081+
1082+
// CHECK-LABEL: @const_fold_vector_logical_equal
1083+
func.func @const_fold_vector_logical_equal() -> vector<3xi1> {
1084+
%cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
1085+
%cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>
1086+
1087+
// CHECK: %[[RET:.*]] = spirv.Constant dense<[true, true, false]>
1088+
%0 = spirv.LogicalEqual %cv0, %cv1 : vector<3xi1>
1089+
1090+
// CHECK: return %[[RET]]
1091+
return %0 : vector<3xi1>
1092+
}
10511093

10521094
// -----
10531095

@@ -1064,6 +1106,43 @@ func.func @convert_logical_not_equal_false(%arg: vector<4xi1>) -> vector<4xi1> {
10641106
spirv.ReturnValue %0 : vector<4xi1>
10651107
}
10661108

1109+
// CHECK-LABEL: @logical_not_equal_same
1110+
func.func @logical_not_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
1111+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1112+
// CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
1113+
%0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
1114+
%1 = spirv.LogicalNotEqual %arg1, %arg1 : vector<3xi1>
1115+
1116+
// CHECK: return %[[CFALSE]], %[[CVFALSE]]
1117+
return %0, %1 : i1, vector<3xi1>
1118+
}
1119+
1120+
// CHECK-LABEL: @const_fold_scalar_logical_not_equal
1121+
func.func @const_fold_scalar_logical_not_equal() -> (i1, i1) {
1122+
%true = spirv.Constant true
1123+
%false = spirv.Constant false
1124+
1125+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1126+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1127+
%0 = spirv.LogicalNotEqual %true, %false : i1
1128+
%1 = spirv.LogicalNotEqual %false, %false : i1
1129+
1130+
// CHECK: return %[[CTRUE]], %[[CFALSE]]
1131+
return %0, %1 : i1, i1
1132+
}
1133+
1134+
// CHECK-LABEL: @const_fold_vector_logical_not_equal
1135+
func.func @const_fold_vector_logical_not_equal() -> vector<3xi1> {
1136+
%cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
1137+
%cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>
1138+
1139+
// CHECK: %[[RET:.*]] = spirv.Constant dense<[false, false, true]>
1140+
%0 = spirv.LogicalNotEqual %cv0, %cv1 : vector<3xi1>
1141+
1142+
// CHECK: return %[[RET]]
1143+
return %0 : vector<3xi1>
1144+
}
1145+
10671146
// -----
10681147

10691148
func.func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
@@ -1139,6 +1218,92 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
11391218

11401219
// -----
11411220

1221+
//===----------------------------------------------------------------------===//
1222+
// spirv.IEqual
1223+
//===----------------------------------------------------------------------===//
1224+
1225+
// CHECK-LABEL: @iequal_same
1226+
func.func @iequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
1227+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1228+
// CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
1229+
%0 = spirv.IEqual %arg0, %arg0 : i32
1230+
%1 = spirv.IEqual %arg1, %arg1 : vector<3xi32>
1231+
1232+
// CHECK: return %[[CTRUE]], %[[CVTRUE]]
1233+
return %0, %1 : i1, vector<3xi1>
1234+
}
1235+
1236+
// CHECK-LABEL: @const_fold_scalar_iequal
1237+
func.func @const_fold_scalar_iequal() -> (i1, i1) {
1238+
%c5 = spirv.Constant 5 : i32
1239+
%c6 = spirv.Constant 6 : i32
1240+
1241+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1242+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1243+
%0 = spirv.IEqual %c5, %c6 : i32
1244+
%1 = spirv.IEqual %c5, %c5 : i32
1245+
1246+
// CHECK: return %[[CFALSE]], %[[CTRUE]]
1247+
return %0, %1 : i1, i1
1248+
}
1249+
1250+
// CHECK-LABEL: @const_fold_vector_iequal
1251+
func.func @const_fold_vector_iequal() -> vector<3xi1> {
1252+
%cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
1253+
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
1254+
1255+
// CHECK: %[[RET:.*]] = spirv.Constant dense<[true, false, true]>
1256+
%0 = spirv.IEqual %cv0, %cv1 : vector<3xi32>
1257+
1258+
// CHECK: return %[[RET]]
1259+
return %0 : vector<3xi1>
1260+
}
1261+
1262+
// -----
1263+
1264+
//===----------------------------------------------------------------------===//
1265+
// spirv.INotEqual
1266+
//===----------------------------------------------------------------------===//
1267+
1268+
// CHECK-LABEL: @inotequal_same
1269+
func.func @inotequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
1270+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1271+
// CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
1272+
%0 = spirv.INotEqual %arg0, %arg0 : i32
1273+
%1 = spirv.INotEqual %arg1, %arg1 : vector<3xi32>
1274+
1275+
// CHECK: return %[[CFALSE]], %[[CVFALSE]]
1276+
return %0, %1 : i1, vector<3xi1>
1277+
}
1278+
1279+
// CHECK-LABEL: @const_fold_scalar_inotequal
1280+
func.func @const_fold_scalar_inotequal() -> (i1, i1) {
1281+
%c5 = spirv.Constant 5 : i32
1282+
%c6 = spirv.Constant 6 : i32
1283+
1284+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1285+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1286+
%0 = spirv.INotEqual %c5, %c6 : i32
1287+
%1 = spirv.INotEqual %c5, %c5 : i32
1288+
1289+
// CHECK: return %[[CTRUE]], %[[CFALSE]]
1290+
return %0, %1 : i1, i1
1291+
}
1292+
1293+
// CHECK-LABEL: @const_fold_vector_inotequal
1294+
func.func @const_fold_vector_inotequal() -> vector<3xi1> {
1295+
%cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
1296+
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
1297+
1298+
// CHECK: %[[RET:.*]] = spirv.Constant dense<[false, true, false]>
1299+
%0 = spirv.INotEqual %cv0, %cv1 : vector<3xi32>
1300+
1301+
// CHECK: return %[[RET]]
1302+
return %0 : vector<3xi1>
1303+
}
1304+
1305+
// -----
1306+
11421307
//===----------------------------------------------------------------------===//
11431308
// spirv.LeftShiftLogical
11441309
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)