Skip to content

Commit a9e68db

Browse files
author
MaheshRavishankar
committed
[mlir] Add canonicalizations for subtensor_insert operation.
Add canonicalizers to the subtensor_insert operation that propagate constant arguments within offsets, sizes and strides. Also add a pattern to propagate tensor_cast operations. Differential Revision: https://reviews.llvm.org/D97704
1 parent 6dbea3e commit a9e68db

File tree

3 files changed

+138
-0
lines changed

3 files changed

+138
-0
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3048,6 +3048,7 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
30483048
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
30493049
}];
30503050

3051+
let hasCanonicalizer = 1;
30513052
let hasFolder = 1;
30523053
}
30533054

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3795,6 +3795,95 @@ OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
37953795
return OpFoldResult();
37963796
}
37973797

3798+
namespace {
3799+
/// Pattern to rewrite a subtensor_insert op with constant arguments.
3800+
class SubTensorInsertOpConstantArgumentFolder final
3801+
: public OpRewritePattern<SubTensorInsertOp> {
3802+
public:
3803+
using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
3804+
3805+
LogicalResult matchAndRewrite(SubTensorInsertOp subTensorInsertOp,
3806+
PatternRewriter &rewriter) const override {
3807+
// No constant operand, just return.
3808+
if (llvm::none_of(subTensorInsertOp.getOperands(), [](Value operand) {
3809+
return matchPattern(operand, m_ConstantIndex());
3810+
}))
3811+
return failure();
3812+
3813+
// At least one of offsets/sizes/strides is a new constant.
3814+
// Form the new list of operands and constant attributes from the existing.
3815+
SmallVector<OpFoldResult> mixedOffsets(subTensorInsertOp.getMixedOffsets());
3816+
SmallVector<OpFoldResult> mixedSizes(subTensorInsertOp.getMixedSizes());
3817+
SmallVector<OpFoldResult> mixedStrides(subTensorInsertOp.getMixedStrides());
3818+
canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
3819+
canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
3820+
canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
3821+
3822+
// Create the new op in canonical form.
3823+
Value source = subTensorInsertOp.source();
3824+
RankedTensorType sourceType = source.getType().cast<RankedTensorType>();
3825+
SmallVector<int64_t, 4> shape = llvm::to_vector<4>(
3826+
llvm::map_range(mixedSizes, [](OpFoldResult valueOrAttr) -> int64_t {
3827+
if (auto attr = valueOrAttr.dyn_cast<Attribute>())
3828+
return attr.cast<IntegerAttr>().getInt();
3829+
return ShapedType::kDynamicSize;
3830+
}));
3831+
RankedTensorType newSourceType =
3832+
RankedTensorType::get(shape, sourceType.getElementType());
3833+
Location loc = subTensorInsertOp.getLoc();
3834+
if (sourceType != newSourceType)
3835+
source = rewriter.create<tensor::CastOp>(loc, newSourceType, source);
3836+
rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
3837+
subTensorInsertOp, source, subTensorInsertOp.dest(), mixedOffsets,
3838+
mixedSizes, mixedStrides);
3839+
return success();
3840+
}
3841+
};
3842+
3843+
/// Fold tensor_casts with subtensor_insert operations.
3844+
struct SubTensorInsertOpCastFolder final
3845+
: public OpRewritePattern<SubTensorInsertOp> {
3846+
using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
3847+
3848+
LogicalResult matchAndRewrite(SubTensorInsertOp subTensorOp,
3849+
PatternRewriter &rewriter) const override {
3850+
if (llvm::any_of(subTensorOp.getOperands(), [](Value operand) {
3851+
return matchPattern(operand, m_ConstantIndex());
3852+
}))
3853+
return failure();
3854+
3855+
auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
3856+
auto castOp = v.getDefiningOp<tensor::CastOp>();
3857+
if (!castOp || !canFoldIntoConsumerOp(castOp))
3858+
return llvm::None;
3859+
return castOp.source();
3860+
};
3861+
Optional<Value> sourceCastSource = getSourceOfCastOp(subTensorOp.source());
3862+
Optional<Value> destCastSource = getSourceOfCastOp(subTensorOp.dest());
3863+
if (!sourceCastSource && !destCastSource &&
3864+
subTensorOp.dest().getType() == subTensorOp.getResult().getType())
3865+
return failure();
3866+
3867+
auto newOp = rewriter.create<SubTensorInsertOp>(
3868+
subTensorOp.getLoc(),
3869+
(sourceCastSource ? *sourceCastSource : subTensorOp.source()),
3870+
(destCastSource ? *destCastSource : subTensorOp.dest()),
3871+
subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
3872+
subTensorOp.getMixedStrides());
3873+
3874+
rewriter.replaceOpWithNewOp<tensor::CastOp>(subTensorOp,
3875+
subTensorOp.getType(), newOp);
3876+
return success();
3877+
}
3878+
};
3879+
} // namespace
3880+
3881+
/// Registers the subtensor_insert canonicalization patterns defined above.
void SubTensorInsertOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SubTensorInsertOpConstantArgumentFolder, // fold constants
                 SubTensorInsertOpCastFolder>             // fold tensor.cast
      (context);
}
3886+
37983887
//===----------------------------------------------------------------------===//
37993888
// TensorLoadOp
38003889
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Standard/canonicalize.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,51 @@ func @rank_reducing_subtensor_insert_of_cast(%a : tensor<16x32xi8>, %b : tensor<
252252
%res = subtensor_insert %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
253253
return %res : tensor<4x6x16x32xi8>
254254
}
255+
256+
// -----
257+
258+
// Constant offsets/sizes/strides of subtensor_insert must be folded into
// static attributes, with a tensor.cast reconciling the refined result type.
func @subtensor_canonicalize(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
    %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %c8 = constant 8 : index
  %dim1 = dim %arg0, %c1 : tensor<2x?xi32>
  %pad = tensor.extract %arg1[] : tensor<i32>
  %fill = tensor.generate %arg2, %c8 {
  ^bb0(%arg4: index, %arg5: index):
    tensor.yield %pad : i32
  } : tensor<?x?xi32>
  %inserted = subtensor_insert %arg0 into %fill[%c0, %arg3] [%c2, %dim1] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
  return %inserted : tensor<?x?xi32>
}
// CHECK-LABEL: func @subtensor_canonicalize
// CHECK: %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
// CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32>
// CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]
// CHECK: return %[[CAST]]
278+
279+
// -----
280+
281+
// tensor.cast ops feeding the dest of subtensor_insert (and trailing its
// result) must be folded away once the static result type can be inferred.
// Fix: the ARG0 FileCheck regex used the class [a-zA-z0-9_] (lowercase `z`),
// which also matches the ASCII range between 'Z' and 'a' ([, \, ], ^, `);
// corrected to [a-zA-Z0-9_] to match the ARG1 capture below.
func @subtensor_insert_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %c9 = constant 9 : index
  %c3 = constant 3 : index
  %2 = tensor.extract %arg1[] : tensor<i32>
  %4 = tensor.generate %c3, %c9 {
  ^bb0(%arg2: index, %arg3: index):
    tensor.yield %2 : i32
  } : tensor<?x?xi32>
  %5 = subtensor_insert %arg0 into %4[%c0, %c1] [%c2, %c3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
  %6 = tensor.cast %5 : tensor<?x?xi32> to tensor<3x9xi32>
  return %6 : tensor<3x9xi32>
}
// CHECK-LABEL: func @subtensor_insert_output_dest_canonicalize
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3xi32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<i32>
// CHECK: %[[PAD:.+]] = tensor.extract %[[ARG1]]
// CHECK: %[[GENERATE:.+]] = tensor.generate
// CHECK: %[[RESULT:.+]] = subtensor_insert %[[ARG0]] into %[[GENERATE]]
// CHECK: return %[[RESULT]]

0 commit comments

Comments
 (0)