
Commit 6bb0ab0

[MLIR] Propagate unpack through element-wise ops
Introduce `pushDownUnPackOpThroughElemGenericOp` to propagate a producer tensor.unpack operation through an element-wise linalg.generic operation. This pattern complements `BubbleUpPackOpThroughElemGenericOp`. The general idea is to bubble up tensor.pack as much as possible while pushing down tensor.unpack as much as possible, and then canonicalize away symmetrical tensor.pack and tensor.unpack operations. Currently, `pushDownUnPackOpThroughElemGenericOp` expects a single tensor.unpack operation as the producer of one of the linalg.generic's operands.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D142523
1 parent 95e49f5 commit 6bb0ab0
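To make the intent concrete, here is a hypothetical before/after sketch (shapes and names are invented for illustration, not taken from this patch): an elementwise generic sandwiched between a tensor.unpack and a symmetric tensor.pack. Once the unpack is pushed below the generic, the adjacent unpack/pack pair has matching layouts and canonicalizes away, leaving a single generic on the packed layout.

#map = affine_map<(d0, d1) -> (d0, d1)>

// Before: the elementwise generic runs on the unpacked domain.
func.func @motivation(%src: tensor<4x16x16x32xf32>,
                      %dst: tensor<4x16x16x32xf32>) -> tensor<4x16x16x32xf32> {
  %e = tensor.empty() : tensor<64x512xf32>
  %u = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [16, 32]
      into %e : tensor<4x16x16x32xf32> -> tensor<64x512xf32>
  %g = linalg.generic {indexing_maps = [#map],
                       iterator_types = ["parallel", "parallel"]}
      outs(%u : tensor<64x512xf32>) {
  ^bb0(%out: f32):
    %a = arith.addf %out, %out : f32
    linalg.yield %a : f32
  } -> tensor<64x512xf32>
  %p = tensor.pack %g inner_dims_pos = [0, 1] inner_tiles = [16, 32]
      into %dst : tensor<64x512xf32> -> tensor<4x16x16x32xf32>
  return %p : tensor<4x16x16x32xf32>
}

// After push-down plus canonicalization (sketch): both layout ops fold away
// and the generic operates directly on tensor<4x16x16x32xf32>:
//   %g = linalg.generic ... outs(%src : tensor<4x16x16x32xf32>) ...
//   return %g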

File tree (2 files changed: +284, -39 lines)

  mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
  mlir/test/Dialect/Linalg/data-layout-propagation.mlir

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 172 additions & 35 deletions
@@ -46,19 +46,22 @@ struct PackInfo {
   SmallVector<int64_t> outerDimsOnDomainPerm;
 };
 
-static PackInfo getPackingInfoFromConsumer(AffineMap indexingMap,
-                                           tensor::PackOp packOp) {
+template <typename OpTy>
+static PackInfo getPackingInfoFromOperand(AffineMap indexingMap,
+                                          OpTy packOrUnPackOp) {
+  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
+                "applies to only pack or unpack operations");
   LLVM_DEBUG(
-      { llvm::dbgs() << "--- Construct PackInfo From A Consumer ---\n"; });
+      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
   PackInfo packInfo;
   int64_t origNumDims = indexingMap.getNumDims();
   SmallVector<AffineExpr> exprs(indexingMap.getResults());
-  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
   for (auto [index, innerDimPos, tileSize] :
        llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
-                       innerDimsPos, packOp.getMixedTiles())) {
+                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
     int64_t domainDimPos =
-        exprs[innerDimPos].cast<AffineDimExpr>().getPosition();
+        exprs[innerDimPos].template cast<AffineDimExpr>().getPosition();
     packInfo.tiledDimsPos.push_back(domainDimPos);
     packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
     packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
@@ -71,7 +74,7 @@ static PackInfo getPackingInfoFromConsumer(AffineMap indexingMap,
     });
   }
 
-  for (auto dim : packOp.getOuterDimsPerm())
+  for (auto dim : packOrUnPackOp.getOuterDimsPerm())
     packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim));
   if (!packInfo.outerDimsOnDomainPerm.empty()) {
     LLVM_DEBUG({
@@ -209,6 +212,35 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   return std::make_tuple(packedOperand, indexingMap);
 }
 
+/// Pack an element-wise genericOp and return it.
+static GenericOp packElementWiseOp(RewriterBase &rewriter, GenericOp genericOp,
+                                   Value dest, AffineMap packedOutIndexingMap,
+                                   const PackInfo &packInfo) {
+  Location loc = genericOp.getLoc();
+  SmallVector<Value> inputOperands;
+  SmallVector<AffineMap> indexingMaps;
+  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
+    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
+        rewriter, loc, packInfo, genericOp, inputOperand);
+    inputOperands.push_back(packedOperand);
+    indexingMaps.push_back(packedIndexingMap);
+  }
+
+  int64_t numInnerLoops = packInfo.getNumTiledLoops();
+  SmallVector<utils::IteratorType> iterTypes =
+      genericOp.getIteratorTypesArray();
+  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
+
+  indexingMaps.push_back(packedOutIndexingMap);
+
+  auto newGenericOp = rewriter.create<linalg::GenericOp>(
+      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+  return newGenericOp;
+}
+
 /// Bubbles up tensor.pack op through elementwise generic op. This
 /// swap pack(generic) to generic(pack). The new generic op works on packed
 /// domain; pack ops are created for input and output operands. E.g.,
@@ -275,29 +307,13 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
     return failure();
 
   OpOperand *opOperand = genericOp.getDpsInitOperand(0);
-  auto packInfo = getPackingInfoFromConsumer(
+  auto packInfo = getPackingInfoFromOperand(
       genericOp.getMatchingIndexingMap(opOperand), packOp);
 
-  Location loc = packOp.getLoc();
-  SmallVector<Value> inputOperands;
-  SmallVector<AffineMap> indexingMaps;
-  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
-    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
-        rewriter, loc, packInfo, genericOp, inputOperand);
-    inputOperands.push_back(packedOperand);
-    indexingMaps.push_back(packedIndexingMap);
-  }
-
-  int64_t numInnerLoops = packInfo.getNumTiledLoops();
-  SmallVector<utils::IteratorType> iterTypes =
-      genericOp.getIteratorTypesArray();
-  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
-
   // Rebuild the indexing map for the corresponding init operand.
   auto [packedOutOperand, packedOutIndexingMap] =
-      getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp,
-                                     opOperand);
-  indexingMaps.push_back(packedOutIndexingMap);
+      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), packInfo,
+                                     genericOp, opOperand);
 
   // We'll replace the init operand with the destination of pack op if the init
   // operand has not users in the body of the linalg.generic (pure elementwise).
@@ -306,15 +322,12 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
   Value dest = (genericOp.getRegionOutputArgs()[0].use_empty())
                    ? packOp.getDest()
                    : packedOutOperand;
-  auto newGenericOp = rewriter.create<linalg::GenericOp>(
-      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
-      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
-  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
-                             newGenericOp.getRegion().begin());
-  return newGenericOp;
+
+  return packElementWiseOp(rewriter, genericOp, dest, packedOutIndexingMap,
+                           packInfo);
 }
 
-// Wrapper pattern that applies bubbleUpPackOpThroughElemGenericOp method.
+/// Wrapper pattern that applies bubbleUpPackOpThroughElemGenericOp method.
 struct BubbleUpPackOpThroughElemGenericOpPattern
     : public OpRewritePattern<tensor::PackOp> {
   using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
@@ -328,10 +341,134 @@ struct BubbleUpPackOpThroughElemGenericOpPattern
     return success();
   }
 };
+
+// TODO: Relax this restriction. We should unpack an elementwise also
+// in the presence of multiple unpack ops as producers.
+/// Return the unpacked operand, if present, for the current generic op.
+static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
+  OpOperand *unPackedOperand = nullptr;
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
+    if (!unPackOp)
+      continue;
+    if (unPackedOperand)
+      return failure();
+    unPackedOperand = &operand;
+  }
+  if (!unPackedOperand)
+    return failure();
+  return unPackedOperand;
+}
+
+/// Push down a tensor.unpack op through elementwise generic op.
+/// The new generic op works on packed domain; pack ops are created for input
+/// and output operands. A tensor.unpack op is inserted right after the packed
+/// generic. E.g.
+///
+/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+///
+/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
+///
+/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
+/// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
+///                          inner_dims_pos = [3] inner_tiles = [32] into %0
+/// %2 = linalg.generic {indexing_maps = [#map],
+///      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+///      outs(%1 : tensor<12x56x56x64xf32>) {
+///   ^bb0(%out : f32):
+///     linalg.yield %out : f32
+/// } -> tensor<12x56x56x64xf32>
+///
+/// will be converted to
+///
+/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+///
+/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
+/// %1 = linalg.generic {indexing_maps = [#map],
+///      iterator_types = ["parallel", "parallel", "parallel",
+///                        "parallel", "parallel"]}
+///      outs(%arg0 : tensor<12x2x56x56x32xf32>) {
+///   ^bb0(%out : f32):
+///     linalg.yield %out : f32
+/// } -> tensor<12x2x56x56x32xf32>
+/// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
+///                       inner_dims_pos = [3] inner_tiles = [32] into %0
+///
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownUnPackOpThroughElemGenericOp(RewriterBase &rewriter,
+                                     GenericOp genericOp) {
+  if (!isElementwise(genericOp))
+    return failure();
+  if (genericOp.getNumResults() != 1)
+    return failure();
+
+  // Collect the unPacked operand, if present.
+  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
+  if (failed(maybeUnPackedOperand))
+    return failure();
+  OpOperand *unPackedOperand = *(maybeUnPackedOperand);
+
+  // Extract packing information.
+  tensor::UnPackOp producerUnPackOp =
+      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
+  assert(producerUnPackOp && "expect a valid UnPackOp");
+  auto packInfo = getPackingInfoFromOperand(
+      genericOp.getMatchingIndexingMap(unPackedOperand), producerUnPackOp);
+
+  // Rebuild the indexing map for the corresponding init operand.
+  auto [packedOutOperand, packedOutIndexingMap] =
+      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), packInfo,
+                                     genericOp, genericOp.getDpsInitOperand(0));
+
+  // If the dps init operand of the generic is a tensor.empty, do not pack it
+  // and forward the new tensor.empty as a destination.
+  Value dest = packedOutOperand;
+  if (auto initTensor = genericOp.getDpsInitOperand(0)
+                            ->get()
+                            .getDefiningOp<tensor::EmptyOp>()) {
+    if (auto packOp = packedOutOperand.getDefiningOp<tensor::PackOp>())
+      dest = packOp.getDest();
+  }
+
+  // Pack the genericOp.
+  GenericOp newGenericOp = packElementWiseOp(rewriter, genericOp, dest,
+                                             packedOutIndexingMap, packInfo);
+
+  auto unPackOp = unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
+  // Insert an unPackOp right after the packed generic.
+  Value unPackOpRes =
+      rewriter
+          .create<tensor::UnPackOp>(
+              genericOp.getLoc(),
+              newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+              unPackOp.getDest(), producerUnPackOp.getInnerDimsPos(),
+              producerUnPackOp.getMixedTiles(),
+              producerUnPackOp.getOuterDimsPerm())
+          .getResult();
+
+  return std::make_tuple(newGenericOp, unPackOpRes);
+}
+
+// Wrapper pattern that applies pushDownUnPackOpThroughElemGenericOp method.
+struct PushDownUnPackOpThroughElemGenericOp
+    : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericAndRepl =
+        pushDownUnPackOpThroughElemGenericOp(rewriter, genericOp);
+    if (failed(genericAndRepl))
+      return failure();
+    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern>(
-      patterns.getContext());
+  patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern,
+                  PushDownUnPackOpThroughElemGenericOp>(patterns.getContext());
 }
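For context, a minimal sketch of how a pass might drive the patterns registered above. The wrapper function here is hypothetical; `populateDataLayoutPropagationPatterns` and `applyPatternsAndFoldGreedily` are the existing MLIR entry points being exercised.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: collect the propagation patterns (including the two
// added by this commit) and apply them greedily until a fixed point.
static LogicalResult propagateDataLayout(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  linalg::populateDataLayoutPropagationPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}

Run together with canonicalization, the bubbled-up packs and pushed-down unpacks meet and symmetric pairs fold away.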

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 112 additions & 4 deletions
@@ -352,15 +352,123 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten
 // CHECK: func.func @elem_pack_transpose_outer_dims
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[ARG0_EMPTY]]
 // CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
 // CHECK: %[[PACKED_ARG1:.+]] = tensor.pack %[[ARG1]]
 // CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
 // CHECK-SAME: into %[[ARG1_EMPTY]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
 // CHECK-SAME: ins(%[[PACKED_ARG0]]
 // CHECK-SAME: outs(%[[PACKED_ARG1]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
+  %0 = tensor.empty() : tensor<12x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+  %2 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%1 : tensor<12x56x56x64xf32>) {
+  ^bb0(%out: f32):
+    %3 = arith.addf %out, %out : f32
+    linalg.yield %3 : f32
+  } -> tensor<12x56x56x64xf32>
+  return %2 : tensor<12x56x56x64xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @unpack_on_output
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
+// CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_EMPTY_PACK]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: outs(%[[PACKED_ARG0]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf32>) -> tensor<12x56x56x64xf32> {
+  %0 = tensor.empty() : tensor<12x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %3 = arith.addf %in, %out : f32
+    linalg.yield %3 : f32
+  } -> tensor<12x56x56x64xf32>
+  return %2 : tensor<12x56x56x64xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @unpack_on_input
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
+// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: ins(%[[ARG0_PACK]]
+// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
+  %init = tensor.empty() : tensor<12x56x56x64xf32>
+  %0 = tensor.empty() : tensor<12x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %3 = arith.addf %in, %in : f32
+    linalg.yield %3 : f32
+  } -> tensor<12x56x56x64xf32>
+  return %2 : tensor<12x56x56x64xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @forward_tensor_empty
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
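Each `// -----` block above is an independent test case: the file's RUN line (unchanged by this commit, hence absent from this diff) pipes the file through mlir-opt with `-split-input-file` and a test pass that installs these propagation patterns, and FileCheck verifies the CHECK lines against the output.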
