
Commit 1a2bb03

[MLIR][LINALG] Add canonicalization pattern in linalg.generic op for static shape inference.
This commit adds a canonicalization pattern to the `linalg.generic` op for static shape inference. If any of the inputs or outputs has a static shape, or is cast from a tensor with a static shape, then the shapes of all the inputs and outputs can be inferred by using the affine map of that static input/output.

Signed-Off-By: Prateek Gupta <[email protected]>
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D118929
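As a quick illustration of the behavior described above, here is a minimal sketch (not taken from this commit; the function name, shapes, and region body are hypothetical) run through `mlir-opt -canonicalize`. The input operand has a fully static shape and an identity indexing map, so the pattern can infer static shapes for the dynamic output operand:

#map = affine_map<(d0, d1) -> (d0, d1)>
func @infer_static_shape(%arg0 : tensor<4x8xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0 : tensor<4x8xf32>)
    outs(%arg1 : tensor<?x?xf32>) {
  ^bb0(%in : f32, %out : f32):
    %1 = arith.negf %in : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

After canonicalization, the dynamic operand is cast to the inferred static type, the generic op produces a statically shaped result, and a trailing tensor.cast restores the original dynamic result type (roughly):

  %cast = tensor.cast %arg1 : tensor<?x?xf32> to tensor<4x8xf32>
  %0 = linalg.generic { ... } ins(%arg0 : tensor<4x8xf32>) outs(%cast : tensor<4x8xf32>) { ... } -> tensor<4x8xf32>
  %1 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x?xf32>
  return %1 : tensor<?x?xf32>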
1 parent c1e4e01 commit 1a2bb03

File tree

3 files changed: +299 -10 lines


mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 159 additions & 1 deletion
@@ -841,11 +841,169 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
+
+/// For each operand in `operands`, this function maps the static sizes of
+/// dimensions to their affine dim expressions.
+static void populateMap(GenericOp genericOp, ArrayRef<OpOperand *> operands,
+                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
+  for (OpOperand *opOperand : operands) {
+    if (genericOp.isScalar(opOperand))
+      continue;
+    Value src = opOperand->get();
+    auto sourceType = src.getType().cast<RankedTensorType>();
+    auto sourceMap = genericOp.getTiedIndexingMap(opOperand);
+
+    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
+    // a `tensor.cast` operation and the source of the cast operation has a
+    // static shape, then assign it to the `sourceShape`.
+    auto parentOp = src.getDefiningOp();
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    if (parentOp) {
+      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
+        Value castSource = castOp.source();
+        auto castSourceType = castSource.getType().cast<RankedTensorType>();
+        if (castSourceType.hasStaticShape())
+          sourceShape = castSourceType.getShape();
+      }
+    }
+
+    // If the source shape's dimension has a static shape, map the affine dim
+    // expression to the known static size.
+    for (unsigned i = 0; i < sourceShape.size(); i++) {
+      if (sourceType.isDynamicDim(i))
+        continue;
+      if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
+        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
+    }
+  }
+}
+
+/// Creates a new operand w.r.t. `opOperand` of `genericOp` with the static
+/// sizes mapped in `affineExprToSize`. New operands are created in
+/// `newOperands` and their result types are stored in `resultTypes`. If
+/// `opOperand` requires no change, `changeNeeded` stays false and the same
+/// operand is added to the `newOperands` list.
+static void createNewOperandWithStaticSizes(
+    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
+    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, GenericOp genericOp,
+    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
+    bool &changeNeeded) {
+  Value src = opOperand->get();
+  newOperands.push_back(src);
+  if (genericOp.isScalar(opOperand))
+    return;
+  auto sourceType = src.getType().cast<RankedTensorType>();
+  Type resultType = sourceType;
+  if (sourceType.hasStaticShape() && genericOp.isOutputTensor(opOperand)) {
+    resultTypes.push_back(resultType);
+    return;
+  }
+  ArrayRef<int64_t> sourceShape = sourceType.getShape();
+  AffineMap sourceMap = genericOp.getTiedIndexingMap(opOperand);
+  SmallVector<int64_t> newShape;
+  // If the operand is updated with a new shape, `newOperandNeeded` will be
+  // true.
+  bool newOperandNeeded = false;
+  for (unsigned i = 0; i < sourceShape.size(); i++) {
+    int64_t dimShape = sourceShape[i];
+    AffineExpr dimExpr = sourceMap.getResult(i);
+    if (affineExprToSize.find(dimExpr) == affineExprToSize.end() ||
+        !sourceType.isDynamicDim(i)) {
+      newShape.push_back(dimShape);
+      continue;
+    }
+    // The dimension has a dynamic shape and the corresponding affine dim
+    // expression is present in the map. So assign the size for the given
+    // affine dim expression to the dimension.
+    newShape.push_back(affineExprToSize[dimExpr]);
+    newOperandNeeded = true;
+  }
+  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
+  if (newOperandNeeded) {
+    changeNeeded = true;
+    // Get the new operand value given its size and element type by
+    // casting it.
+    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
+    unsigned index = opOperand->getOperandNumber();
+    newOperands[index] = newOperand;
+  }
+  if (genericOp.isOutputTensor(opOperand))
+    resultTypes.push_back(resultType);
+}
+
+/// Static shapes for the operands can be inferred if any one of the operands
+/// has a static shape. This can be done by referring to the affine dim
+/// expressions for the operand.
+struct InferStaticShapeOfOperands : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    // Maps must be projected permutations.
+    if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) {
+          return !map.isProjectedPermutation();
+        }))
+      return failure();
+
+    // Maps affine dim expressions to the static size of that dimension.
+    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
+    Location loc = genericOp.getLoc();
+
+    // For each affine dim expression, check if the size is known. If it is
+    // known, add it to the map.
+    populateMap(genericOp, genericOp.getInputAndOutputOperands(),
+                affineExprToSize);
+
+    SmallVector<Value> newOperands;
+    SmallVector<Type> resultTypes;
+
+    // `changeNeeded` is `false` if the operands of `genericOp` require no
+    // change in their types.
+    bool changeNeeded = false;
+    newOperands.reserve(genericOp.getNumInputsAndOutputs());
+    resultTypes.reserve(genericOp.getNumOutputs());
+
+    // Iterate over all the operands and update the static sizes.
+    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
+      createNewOperandWithStaticSizes(loc, rewriter, opOperand,
+                                      affineExprToSize, genericOp, newOperands,
+                                      resultTypes, changeNeeded);
+    }
+
+    // If the generic op already has all the required static information, no
+    // canonicalization is needed.
+    if (!changeNeeded)
+      return failure();
+
+    // Clone op.
+    Operation *newOp =
+        cast<linalg::LinalgOp>(genericOp.getOperation())
+            .clone(rewriter, genericOp->getLoc(), resultTypes, newOperands);
+    SmallVector<Value> replacements;
+    replacements.reserve(newOp->getNumResults());
+    for (auto it : llvm::zip(genericOp->getResults(), newOp->getResults())) {
+      Value newResult = std::get<1>(it);
+      Value oldResult = std::get<0>(it);
+      Type newType = newResult.getType();
+      Type oldType = oldResult.getType();
+      replacements.push_back(
+          (newType != oldType)
+              ? rewriter.create<tensor::CastOp>(loc, newType, newResult)
+              : newResult);
+    }
+    rewriter.replaceOp(genericOp, replacements);
+    return success();
+  }
+};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
+  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
+              InferStaticShapeOfOperands>(context);
 }
 
 //===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 130 additions & 0 deletions
@@ -650,3 +650,133 @@ func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
   } : tensor<400x273xf32> to tensor<412x276xf32>
   return %pad : tensor<412x276xf32>
 }
+
+// -----
+
+// Tests below verify whether static information is propagated through all the operands of the generic op.
+// 1. If one of the inputs of the generic op has static info and it has no cast source.
+// 2. If one of the inputs of the generic op has static info and it is coming from a tensor.cast operation.
+// 3. If one of the outputs of the generic op has static info and it is coming from a tensor.cast operation.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_input_without_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<?x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+    ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+      %9 = arith.addf %arg2, %arg3 : f32
+      linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %5 = tensor.cast %4 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %5 : tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_input_with_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+  %5 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %4 : tensor<2x3x4xf32>, tensor<2x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+    ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+      %9 = arith.addf %arg2, %arg3 : f32
+      linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %6 = tensor.cast %5 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %6: tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_output_with_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>, %[[ARG2:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+func @static_output_with_cast(%arg0 : tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg2, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg2, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg2, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %3 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  %5 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+  %6 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %5 : tensor<?x?x?xf32>, tensor<2x?x?xf32>)
+    outs(%4 : tensor<2x3x4xf32>) {
+    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+      %9 = arith.addf %arg3, %arg4 : f32
+      linalg.yield %9 : f32
+  } -> (tensor<2x3x4xf32>)
+  return %6: tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+// This test checks the folding of the tensor.cast operation when the source value of the cast
+// has more static information than the destination value.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @cast_source
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %arg0 : tensor<2x3x4xf32> to tensor<2x?x?xf32>
+  %5 = tensor.cast %arg1 : tensor<2x3x4xf32> to tensor<2x?x?xf32>
+  %6 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%4, %5 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+    ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+      %9 = arith.addf %arg2, %arg3 : f32
+      linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %7 = tensor.cast %6 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %7: tensor<2x3x4xf32>
+  // CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}

mlir/test/Dialect/Linalg/reshape_fusion.mlir

Lines changed: 10 additions & 9 deletions
@@ -533,27 +533,28 @@ func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
 
 // -----
 
-func @no_fuse_mismatched_dynamism(%arg0: tensor<1x1xi64>, %arg1: tensor<?xi64>) -> tensor<1xi64> {
-  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
-  %1 = linalg.init_tensor [1] : tensor<1xi64>
+func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi64>) -> tensor<2xi64> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x1xi64> into tensor<2xi64>
+  %1 = linalg.init_tensor [2] : tensor<2xi64>
   %2 = linalg.generic
     {indexing_maps = [affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]}
-    ins(%0, %arg1 : tensor<1xi64>, tensor<?xi64>)
-    outs(%1 : tensor<1xi64>) {
+    ins(%0, %arg1 : tensor<2xi64>, tensor<?xi64>)
+    outs(%1 : tensor<2xi64>) {
   ^bb0(%arg4: i64, %arg5: i64, %arg6: i64):
     %3 = arith.addi %arg4, %arg5 : i64
     linalg.yield %3 : i64
-  } -> tensor<1xi64>
-  return %2 : tensor<1xi64>
+  } -> tensor<2xi64>
+  return %2 : tensor<2xi64>
 }
 
 // CHECK: func @no_fuse_mismatched_dynamism
-// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1xi64>
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x1xi64>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<?xi64>
 // CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?xi64> to tensor<2xi64>
 // CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<1xi64>, tensor<?xi64>)
+// CHECK-SAME: ins(%[[RESHAPE]], %[[CAST]] : tensor<2xi64>, tensor<2xi64>)
 // CHECK: return %[[GENERIC]]
