Skip to content

Commit 0775088

Browse files
authored
[mlir] Rename GeneralizeOuterUnitDims{Un}PackOpPatterns (llvm#116439)
Renames: * `GeneralizeOuterUnitDimsPackOpPattern`, * `GeneralizeOuterUnitDimsUnPackOpPattern`, as * `DecomposeOuterUnitDimsPackOpPattern`, * `DecomposeOuterUnitDimsUnPackOpPattern`, respectively. The new name better describes the underlying transformation.
1 parent 91c1699 commit 0775088

File tree

11 files changed

+37
-33
lines changed

11 files changed

+37
-33
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

+4-5
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,12 @@ def ApplyEraseUnnecessaryInputsPatternsOp : Op<Transform_Dialect,
4141
let assemblyFormat = "attr-dict";
4242
}
4343

44-
def ApplyGeneralizeTensorPackUnpackPatternsOp
45-
: Op<Transform_Dialect, "apply_patterns.linalg.generalize_pack_unpack",
44+
def ApplyDecomposeTensorPackUnpackPatternsOp
45+
: Op<Transform_Dialect, "apply_patterns.linalg.decompose_pack_unpack",
4646
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
4747
let description = [{
48-
Collect patterns to generalize tensor.pack and tensor.unpack (i.e. to
49-
decompose it into e.g. tensor::PadOp, linalg::transposeOp etc). Requires
50-
all outer dims to be unit.
48+
Collect patterns to decompose tensor.pack and tensor.unpack into e.g.
49+
tensor::PadOp, linalg::transposeOp Ops. Requires all outer dims to be unit.
5150
}];
5251

5352
let assemblyFormat = "attr-dict";

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -1548,7 +1548,7 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
15481548
/// into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
15491549
/// : tensor<2x?xf32> into tensor<1x1x2x?xf32>
15501550
/// ```
1551-
struct GeneralizeOuterUnitDimsPackOpPattern
1551+
struct DecomposeOuterUnitDimsPackOpPattern
15521552
: public OpRewritePattern<tensor::PackOp> {
15531553
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
15541554
LogicalResult matchAndRewrite(tensor::PackOp packOp,
@@ -1558,7 +1558,7 @@ struct GeneralizeOuterUnitDimsPackOpPattern
15581558
/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced extract_slice op
15591559
/// + transpose op + insert_slice op, where the tensor::UnPackOp has outer dims
15601560
/// being all 1s.
1561-
struct GeneralizeOuterUnitDimsUnPackOpPattern
1561+
struct DecomposeOuterUnitDimsUnPackOpPattern
15621562
: public OpRewritePattern<tensor::UnPackOp> {
15631563
using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
15641564
LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp,
@@ -1686,7 +1686,7 @@ void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
16861686
/// Populates patterns to decompose tensor.pack and tensor.unpack Ops into e.g.
16871687
/// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Require all
16881688
/// outer dims to be unit.
1689-
void populateGeneralizePatterns(RewritePatternSet &patterns);
1689+
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns);
16901690

16911691
/// Populates patterns to transform linalg.conv_2d_xxx operations into
16921692
/// linalg.generic (for img2col packing) and linalg.matmul.

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,9 @@ void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
229229
linalg::populateEraseUnnecessaryInputsPatterns(patterns);
230230
}
231231

232-
void transform::ApplyGeneralizeTensorPackUnpackPatternsOp::populatePatterns(
232+
void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
233233
RewritePatternSet &patterns) {
234-
linalg::populateGeneralizePatterns(patterns);
234+
linalg::populateDecomposePackUnpackPatterns(patterns);
235235
}
236236

237237
void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -1138,7 +1138,7 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
11381138
return perm;
11391139
}
11401140

1141-
LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
1141+
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11421142
tensor::PackOp packOp, PatternRewriter &rewriter) const {
11431143
// TODO: support the case that outer dimensions are not all 1s. A
11441144
// tensor.expand_shape will be generated in this case.
@@ -1239,7 +1239,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
12391239
return success();
12401240
}
12411241

1242-
LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
1242+
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12431243
tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
12441244
int64_t srcRank = unpackOp.getSourceRank();
12451245
int64_t destRank = unpackOp.getDestRank();
@@ -1619,7 +1619,7 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
16191619
patterns.getContext(), benefit);
16201620
}
16211621

1622-
void linalg::populateGeneralizePatterns(RewritePatternSet &patterns) {
1622+
void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
16231623
// TODO: Add and test patterns for tensor.unpack
1624-
patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(patterns.getContext());
1624+
patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
16251625
}

mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir renamed to mlir/test/Dialect/Linalg/decompose-tensor-pack-tile.mlir

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s
1+
// RUN: mlir-opt -split-input-file -transform-interpreter --canonicalize \
2+
// RUN: -transform-preload-library='transform-library-paths=%p/td/decompose-pack.mlir' \
3+
// RUN: -transform-interpreter=entry-point=decompose_pack \
4+
// RUN: -transform-interpreter %s | FileCheck %s
25

36
func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8x32xf32>) -> tensor<1x1x4x8x8x32xf32> {
47
%0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x128x64xf32> -> tensor<1x1x4x8x8x32xf32>

mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir renamed to mlir/test/Dialect/Linalg/decompose-tensor-pack.mlir

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
// RUN: mlir-opt --transform-preload-library='transform-library-paths=%p/td/generalize-pack.mlir' -split-input-file --transform-interpreter %s | FileCheck %s
1+
// RUN: mlir-opt -split-input-file \
2+
// RUN: -transform-preload-library='transform-library-paths=%p/td/decompose-pack.mlir' \
3+
// RUN: -transform-interpreter=entry-point=decompose_pack %s | FileCheck %s
24

35
func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32> {
46
%c8 = arith.constant 8 : index

mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir renamed to mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-unpack" %s | FileCheck %s
1+
// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-decompose-tensor-unpack" %s | FileCheck %s
22

33
func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> {
44
%0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32>

mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir renamed to mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-unpack" %s | FileCheck %s
1+
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-decompose-tensor-unpack" %s | FileCheck %s
22

33
func.func @simple_KCRSsr_to_KCRS(%arg0: tensor<1x1x1x1x8x32xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32> {
44
%0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x1x1x8x32xf32> -> tensor<1x1x32x8xf32>

mlir/test/Dialect/Linalg/td/generalize-pack.mlir renamed to mlir/test/Dialect/Linalg/td/decompose-pack.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
module @transforms attributes { transform.with_named_sequence } {
2-
transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
2+
transform.named_sequence @decompose_pack(%module: !transform.any_op {transform.readonly}) {
33
%pack = transform.structured.match ops{["tensor.pack"]} in %module : (!transform.any_op) -> !transform.any_op
44

55
%1 = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.any_op
66
transform.apply_patterns to %1 {
7-
transform.apply_patterns.linalg.generalize_pack_unpack
7+
transform.apply_patterns.linalg.decompose_pack_unpack
88
} : !transform.any_op
99

1010
transform.yield

mlir/test/Integration/Dialect/Linalg/CPU/pack-dynamic-inner-tile.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// DEFINE: %{compile} = mlir-opt %s \
22
// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule |\
3-
// DEFINE: mlir-opt --test-linalg-transform-patterns="test-generalize-tensor-pack"\
3+
// DEFINE: mlir-opt --test-linalg-transform-patterns="test-decompose-tensor-pack"\
44
// DEFINE: --test-transform-dialect-erase-schedule \
55
// DEFINE: -one-shot-bufferize="bufferize-function-boundaries" \
66
// DEFINE: -buffer-deallocation-pipeline="private-function-dynamic-ownership" \

mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

+12-12
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@ struct TestLinalgTransforms
7474
*this, "test-generalize-pad-tensor",
7575
llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
7676
llvm::cl::init(false)};
77-
Option<bool> testGeneralizeTensorPackOp{
78-
*this, "test-generalize-tensor-pack",
77+
Option<bool> testDecomposeTensorPackOp{
78+
*this, "test-decompose-tensor-pack",
7979
llvm::cl::desc("Test transform that generalizes pack ops into a sequence "
8080
"of tensor and Linalg ops"),
8181
llvm::cl::init(false)};
82-
Option<bool> testGeneralizeTensorUnPackOp{
83-
*this, "test-generalize-tensor-unpack",
82+
Option<bool> testDecomposeTensorUnPackOp{
83+
*this, "test-decompose-tensor-unpack",
8484
llvm::cl::desc(
8585
"Test transform that generalizes unpack ops into a sequence "
8686
"of tensor and Linalg ops"),
@@ -172,15 +172,15 @@ static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
172172
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
173173
}
174174

175-
static void applyGeneralizeTensorPackPatterns(func::FuncOp funcOp) {
175+
static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp) {
176176
RewritePatternSet patterns(funcOp.getContext());
177-
patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(funcOp.getContext());
177+
patterns.add<DecomposeOuterUnitDimsPackOpPattern>(funcOp.getContext());
178178
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
179179
}
180180

181-
static void applyGeneralizeTensorUnPackPatterns(func::FuncOp funcOp) {
181+
static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp) {
182182
RewritePatternSet patterns(funcOp.getContext());
183-
patterns.add<GeneralizeOuterUnitDimsUnPackOpPattern>(funcOp.getContext());
183+
patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(funcOp.getContext());
184184
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
185185
}
186186

@@ -237,10 +237,10 @@ void TestLinalgTransforms::runOnOperation() {
237237
return applyLinalgToVectorPatterns(getOperation());
238238
if (testGeneralizePadTensor)
239239
return applyGeneralizePadTensorPatterns(getOperation());
240-
if (testGeneralizeTensorPackOp)
241-
return applyGeneralizeTensorPackPatterns(getOperation());
242-
if (testGeneralizeTensorUnPackOp)
243-
return applyGeneralizeTensorUnPackPatterns(getOperation());
240+
if (testDecomposeTensorPackOp)
241+
return applyDecomposeTensorPackPatterns(getOperation());
242+
if (testDecomposeTensorUnPackOp)
243+
return applyDecomposeTensorUnPackPatterns(getOperation());
244244
if (testSwapSubTensorPadTensor)
245245
return applyExtractSliceOfPadTensorSwapPattern(getOperation());
246246
if (testBubbleUpExtractSliceOpPattern)

0 commit comments

Comments
 (0)