Commit 30d542f

[MLIR][Tensor] Introduce a pattern to propagate tensor.unpack through tensor.pad

Introduce a pattern to "push down" a `tensor.unpack` through a `tensor.pad`. The propagation happens only if the unpack does not touch the padded dimensions.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D143907
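In IR terms, the pad is re-expressed on the packed tensor and the unpack is re-created after it. A minimal before/after sketch, mirroring the first new test below (value names are illustrative and the pad body is elided):

Before:
  %unpacked = tensor.unpack %src outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
      into %init : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
  %padded = tensor.pad %unpacked low[0, 1, 1, 0] high[0, 1, 1, 0] { ... } : tensor<1x56x56x64xf32> to tensor<1x58x58x64xf32>

After:
  %padded = tensor.pad %src low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0] { ... } : tensor<1x2x56x56x32xf32> to tensor<1x2x58x58x32xf32>
  %unpacked = tensor.unpack %padded outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
      into %empty : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>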

2 files changed: 128 additions & 3 deletions


mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 61 additions & 2 deletions
@@ -465,10 +465,69 @@ struct PushDownUnPackOpThroughElemGenericOp
   }
 };
 
+/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
+/// add as many zero padding dimensions in `high` and `low` as the number of
+/// point loops.
+struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::UnPackOp unpackOp =
+        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
+    if (!unpackOp)
+      return failure();
+
+    Location loc = padOp.getLoc();
+    // Bail out if one of the padded dimensions is a tiled one.
+    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
+    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
+    llvm::SmallBitVector innerDims(paddedDims.size());
+    for (int64_t dim : innerDimsPos)
+      innerDims.flip(dim);
+    if (paddedDims.anyCommon(innerDims))
+      return failure();
+
+    Value paddingVal = padOp.getConstantPaddingValue();
+    if (!paddingVal)
+      return failure();
+
+    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
+    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
+      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
+    }
+    // Add zero padding for the point loops.
+    size_t pointLoopsSize = innerDimsPos.size();
+    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
+        paddingVal, padOp.getNofold());
+
+    // Inject the tensor.unpack right after the packed padOp.
+    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
+        loc, padOp.getResultType().getShape(),
+        padOp.getResultType().getElementType());
+
+    Value replacement = rewriter.create<tensor::UnPackOp>(
+        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
+        unpackOp.getMixedTiles(), outerDimsPerm);
+    rewriter.replaceOp(padOp, replacement);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern,
-                  PushDownUnPackOpThroughElemGenericOp>(patterns.getContext());
+  patterns
+      .insert<BubbleUpPackOpThroughElemGenericOpPattern,
+              PushDownUnPackOpThroughElemGenericOp, PushDownUnPackThroughPadOp>(
+          patterns.getContext());
 }
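The `applyPermutationToVector` step is easiest to see on a concrete case. With `outer_dims_perm = [0, 3, 1, 2]`, the unpacked-domain pads `low[1, 1, 1, 0]` of the second test below are reordered to `[1, 0, 1, 1]` in the packed domain, and one trailing zero is appended for the single point loop. A sketch of the resulting pad, with illustrative value names (the pad body is elided):

  %padded = tensor.pad %packed low[1, 0, 1, 1, 0] high[0, 0, 1, 1, 0] {
    ...
  } : tensor<1x2x56x56x32xf32> to tensor<2x2x58x58x32xf32>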

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 67 additions & 1 deletion
@@ -471,4 +471,70 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-SAME: outs(%[[DEST]]
 // CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+
+// -----
+
+func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %padded = tensor.pad %1 low[0, 1, 1, 0] high[0, 1, 1, 0] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x56x56x64xf32> to tensor<1x58x58x64xf32>
+  return %padded : tensor<1x58x58x64xf32>
+}
+
+// CHECK: func.func @pad_valid_propagation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
+
+// -----
+
+func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<2x58x58x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %padded = tensor.pad %1 low[1, 1, 1, 0] high[0, 1, 1, 0] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x56x56x64xf32> to tensor<2x58x58x64xf32>
+  return %padded : tensor<2x58x58x64xf32>
+}
+
+// CHECK: func.func @pad_valid_propagation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[1, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x58x58x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<2x2x58x58x32xf32> -> tensor<2x58x58x64xf32>
+
+// -----
+
+func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x66xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %padded = tensor.pad %1 low[0, 1, 1, 1] high[0, 1, 1, 1] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x56x56x64xf32> to tensor<1x58x58x66xf32>
+  return %padded : tensor<1x58x58x66xf32>
+}
+
+// CHECK: func.func @pad_along_unpacked_dim(
+// CHECK: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
