Skip to content

Commit 1276ce9

Browse files
committed
Revert "[mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. (#104783)"
This reverts commit 0348373 and 99c8557, which is a fix-up on top of the former. I'm reverting because this commit broke two tests: mlir/test/python/integration/dialects/linalg/opsrun.py mlir/test/python/integration/dialects/transform.py See https://lab.llvm.org/buildbot/#/builders/138/builds/4872 I'm not familiar with the tests, so I'm leaving it to the original author to either remove or adapt the broken tests, as discussed here: #104783 (comment)
1 parent 72f339d commit 1276ce9

File tree

14 files changed

+182
-943
lines changed

14 files changed

+182
-943
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -684,16 +684,6 @@ def LinalgStructuredInterface
684684
return;
685685
}]
686686
>,
687-
InterfaceMethod<
688-
/*desc=*/[{
689-
Return true if the user has supplied an explicit indexing maps for this op.
690-
}],
691-
/*retTy=*/"bool",
692-
/*methodName=*/"hasUserDefinedMaps",
693-
/*args=*/(ins),
694-
/*methodBody=*/"",
695-
/*defaultImplementation=*/[{ return false; }]
696-
>,
697687
//===------------------------------------------------------------------===//
698688
// Linalg generalization hooks.
699689
//===------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,78 @@ structured_op: !LinalgStructuredOpConfig
10651065
- !ScalarExpression
10661066
scalar_arg: rhs
10671067
--- !LinalgOpConfig
1068+
metadata: !LinalgOpMetadata
1069+
name: matmul
1070+
cpp_class_name: MatmulOp
1071+
doc: |-
1072+
Performs a matrix multiplication of two 2D inputs.
1073+
1074+
Numeric casting is performed on the operands to the inner multiply, promoting
1075+
them to the same data type as the accumulator/output.
1076+
implements:
1077+
- LinalgContractionOpInterface
1078+
structured_op: !LinalgStructuredOpConfig
1079+
args:
1080+
- !LinalgOperandDefConfig
1081+
name: A
1082+
kind: input_tensor
1083+
type_var: T1
1084+
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
1085+
- !LinalgOperandDefConfig
1086+
name: B
1087+
kind: input_tensor
1088+
type_var: T2
1089+
shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
1090+
- !LinalgOperandDefConfig
1091+
name: C
1092+
kind: output_tensor
1093+
type_var: U
1094+
shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
1095+
- !LinalgOperandDefConfig
1096+
name: cast
1097+
kind: type_fn_attr
1098+
default_fn: cast_signed
1099+
indexing_maps: !LinalgIndexingMapsConfig
1100+
static_indexing_maps:
1101+
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
1102+
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
1103+
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
1104+
iterator_types:
1105+
- parallel
1106+
- parallel
1107+
- reduction
1108+
assignments:
1109+
- !ScalarAssign
1110+
arg: C
1111+
value: !ScalarExpression
1112+
scalar_fn:
1113+
kind: binary
1114+
fn_name: add
1115+
operands:
1116+
- !ScalarExpression
1117+
scalar_arg: C
1118+
- !ScalarExpression
1119+
scalar_fn:
1120+
kind: binary
1121+
fn_name: mul
1122+
operands:
1123+
- !ScalarExpression
1124+
scalar_fn:
1125+
kind: type
1126+
attr_name: cast
1127+
type_var: U
1128+
operands:
1129+
- !ScalarExpression
1130+
scalar_arg: A
1131+
- !ScalarExpression
1132+
scalar_fn:
1133+
kind: type
1134+
attr_name: cast
1135+
type_var: U
1136+
operands:
1137+
- !ScalarExpression
1138+
scalar_arg: B
1139+
--- !LinalgOpConfig
10681140
metadata: !LinalgOpMetadata
10691141
name: quantized_matmul
10701142
cpp_class_name: QuantizedMatmulOp

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 0 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -535,140 +535,6 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
535535
let hasCanonicalizer = 1;
536536
}
537537

538-
//===----------------------------------------------------------------------===//
539-
// Op definition for MatmulOp
540-
//===----------------------------------------------------------------------===//
541-
542-
def MatmulOp : LinalgStructuredBase_Op<"matmul", [
543-
AttrSizedOperandSegments,
544-
LinalgContractionOpInterface]> {
545-
546-
let summary = [{
547-
Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
548-
}];
549-
let description = [{
550-
Numeric casting is performed on the operands to the inner multiply,
551-
promoting them to the same data type as the accumulator/output.
552-
553-
Broadcast and Transpose semantics can be applied by specifying the explicit attribute
554-
'indexing_maps' as shown below. This is a list attribute, so the list must include all
555-
the maps if specified.
556-
557-
Example Transpose:
558-
```
559-
linalg.matmul indexing_maps = [
560-
affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
561-
affine_map<(d0, d1, d2) -> (d2, d1)>,
562-
affine_map<(d0, d1, d2) -> (d0, d1)>
563-
]
564-
ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
565-
outs(%arg2: memref<3x7xf32>)
566-
```
567-
568-
Example Broadcast:
569-
```
570-
linalg.matmul indexing_maps = [
571-
affine_map<(d0, d1, d2) -> (d2)>, // broadcast
572-
affine_map<(d0, d1, d2) -> (d2, d1)>,
573-
affine_map<(d0, d1, d2) -> (d0, d1)>
574-
]
575-
ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
576-
outs(%arg2: memref<3x7xf32>)
577-
```
578-
579-
Example Broadcast and transpose:
580-
```
581-
linalg.matmul indexing_maps = [
582-
affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
583-
affine_map<(d0, d1, d2) -> (d2)>, // broadcast
584-
affine_map<(d0, d1, d2) -> (d0, d1)>
585-
]
586-
ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
587-
}];
588-
589-
let arguments = (ins
590-
Variadic<AnyType>:$inputs,
591-
Variadic<AnyShaped>:$outputs,
592-
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
593-
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
594-
);
595-
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
596-
let regions = (region AnyRegion:$region);
597-
598-
let skipDefaultBuilders = 1;
599-
let builders = [
600-
OpBuilder<
601-
(ins "ValueRange":$inputs, "ValueRange":$outputs,
602-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
603-
[{
604-
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
605-
attributes, MatmulOp::getRegionBuilder());
606-
}]>,
607-
OpBuilder<
608-
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
609-
"ValueRange":$outputs,
610-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
611-
[{
612-
buildStructuredOp($_builder, $_state, resultTensorTypes,
613-
inputs, outputs, attributes, MatmulOp::getRegionBuilder());
614-
}]>,
615-
OpBuilder<
616-
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
617-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
618-
[{
619-
$_state.addOperands(operands);
620-
$_state.addAttributes(attributes);
621-
$_state.addTypes(resultTensorTypes);
622-
(void)$_state.addRegion();
623-
}]>,
624-
OpBuilder<
625-
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
626-
"ValueRange":$outputs,
627-
"Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
628-
[{
629-
$_state.addAttribute("cast", cast);
630-
buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
631-
attributes, MatmulOp::getRegionBuilder());
632-
}]>
633-
634-
];
635-
let hasCustomAssemblyFormat = 1;
636-
let hasFolder = 1;
637-
let hasVerifier = 1;
638-
639-
let extraClassDeclaration = structuredOpsBaseDecls # [{
640-
SmallVector<utils::IteratorType> getIteratorTypesArray();
641-
642-
/// Implements the block region builder.
643-
static void regionBuilder(ImplicitLocOpBuilder &b,
644-
Block &block, ArrayRef<NamedAttribute> attrs);
645-
646-
/// Returns a list of AffineMap with the typical matmul indexing characteristic.
647-
SmallVector<AffineMap> getDefaultIndexingMaps();
648-
649-
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
650-
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
651-
652-
static std::function<void(ImplicitLocOpBuilder &,
653-
Block &, ArrayRef<NamedAttribute>)>
654-
getRegionBuilder() {
655-
return regionBuilder;
656-
}
657-
658-
::mlir::MutableOperandRange getDpsInitsMutable() {
659-
return getOutputsMutable();
660-
}
661-
662-
// Generic methods.
663-
static unsigned getNumRegionArgs();
664-
std::string getLibraryCallName();
665-
bool hasDynamicIndexingMaps();
666-
/// Check if the op has broadcast and/or transpose semantic. Returns true if the
667-
/// user defined indexing maps are not equal to default map.
668-
bool hasUserDefinedMaps();
669-
}];
670-
}
671-
672538
//===----------------------------------------------------------------------===//
673539
// Named Linalg ops, implemented as a declarative configurations of generic ops.
674540
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,13 @@
1515
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1616
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1717
#include "mlir/Dialect/Tensor/IR/Tensor.h"
18-
#include "mlir/IR/AffineExpr.h"
1918
#include "mlir/IR/AffineExprVisitor.h"
2019
#include "mlir/IR/AffineMap.h"
21-
#include "mlir/IR/BuiltinTypeInterfaces.h"
22-
#include "mlir/IR/MLIRContext.h"
2320
#include "mlir/IR/TypeUtilities.h"
24-
#include "llvm/ADT/STLExtras.h"
2521
#include "llvm/ADT/SetOperations.h"
2622
#include "llvm/ADT/SmallBitVector.h"
2723
#include "llvm/ADT/SmallVector.h"
28-
#include "llvm/Support/Casting.h"
29-
#include "llvm/Support/raw_ostream.h"
3024
#include <algorithm>
31-
#include <optional>
3225

3326
using namespace mlir;
3427
using namespace mlir::linalg;
@@ -1149,6 +1142,7 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
11491142

11501143
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
11511144
LinalgOp linalgOp = cast<LinalgOp>(op);
1145+
11521146
// Mixed tensor/buffer operands are not allowed.
11531147
if (!linalgOp.hasPureTensorSemantics() &&
11541148
!linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1168,8 +1162,6 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
11681162
<< ") to be equal to the number of input/output operands ("
11691163
<< linalgOp->getNumOperands() << ")";
11701164

1171-
// Set this flag if this op has user defined maps. This is required to guard
1172-
// the below error condition which assume default indexing maps.
11731165
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
11741166
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
11751167

@@ -1186,13 +1178,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
11861178
<< " dim(s) to match the number of loops";
11871179

11881180
int64_t rank = linalgOp.getRank(&opOperand);
1189-
11901181
if (indexingMap.getNumResults() != rank)
11911182
return op->emitOpError("expected operand rank (")
11921183
<< rank << ") to match the result rank of indexing_map #"
11931184
<< opOperand.getOperandNumber() << " ("
11941185
<< indexingMap.getNumResults() << ")";
11951186
}
1187+
11961188
SmallVector<unsigned> redDims;
11971189
linalgOp.getReductionDims(redDims);
11981190

@@ -1202,8 +1194,9 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
12021194
// Check if given shapes match to inferred shapes.
12031195
SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
12041196
SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
1205-
// Verify only static cases since we can't get exact dimension sizes and
1206-
// loop ranges for dynamic cases in this stage.
1197+
1198+
// Verify only static cases since we can't get exact dimension sizes and loop
1199+
// ranges for dynamic cases in this stage.
12071200
if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
12081201
for (int64_t &range : endLoopRangeValues)
12091202
range -= 1;

0 commit comments

Comments
 (0)