@@ -46,19 +46,22 @@ struct PackInfo {
46
46
SmallVector<int64_t > outerDimsOnDomainPerm;
47
47
};
48
48
49
- static PackInfo getPackingInfoFromConsumer (AffineMap indexingMap,
50
- tensor::PackOp packOp) {
49
+ template <typename OpTy>
50
+ static PackInfo getPackingInfoFromOperand (AffineMap indexingMap,
51
+ OpTy packOrUnPackOp) {
52
+ static_assert (llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
53
+ " applies to only pack or unpack operations" );
51
54
LLVM_DEBUG (
52
- { llvm::dbgs () << " --- Construct PackInfo From A Consumer ---\n " ; });
55
+ { llvm::dbgs () << " --- Construct PackInfo From an operand ---\n " ; });
53
56
PackInfo packInfo;
54
57
int64_t origNumDims = indexingMap.getNumDims ();
55
58
SmallVector<AffineExpr> exprs (indexingMap.getResults ());
56
- ArrayRef<int64_t > innerDimsPos = packOp .getInnerDimsPos ();
59
+ ArrayRef<int64_t > innerDimsPos = packOrUnPackOp .getInnerDimsPos ();
57
60
for (auto [index, innerDimPos, tileSize] :
58
61
llvm::zip_equal (llvm::seq<unsigned >(0 , innerDimsPos.size ()),
59
- innerDimsPos, packOp .getMixedTiles ())) {
62
+ innerDimsPos, packOrUnPackOp .getMixedTiles ())) {
60
63
int64_t domainDimPos =
61
- exprs[innerDimPos].cast <AffineDimExpr>().getPosition ();
64
+ exprs[innerDimPos].template cast <AffineDimExpr>().getPosition ();
62
65
packInfo.tiledDimsPos .push_back (domainDimPos);
63
66
packInfo.domainDimAndTileMapping [domainDimPos] = tileSize;
64
67
packInfo.tileToPointMapping [domainDimPos] = origNumDims + index;
@@ -71,7 +74,7 @@ static PackInfo getPackingInfoFromConsumer(AffineMap indexingMap,
71
74
});
72
75
}
73
76
74
- for (auto dim : packOp .getOuterDimsPerm ())
77
+ for (auto dim : packOrUnPackOp .getOuterDimsPerm ())
75
78
packInfo.outerDimsOnDomainPerm .push_back (indexingMap.getDimPosition (dim));
76
79
if (!packInfo.outerDimsOnDomainPerm .empty ()) {
77
80
LLVM_DEBUG ({
@@ -209,6 +212,35 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
209
212
return std::make_tuple (packedOperand, indexingMap);
210
213
}
211
214
215
+ // / Pack an element-wise genericOp and return it.
216
+ static GenericOp packElementWiseOp (RewriterBase &rewriter, GenericOp genericOp,
217
+ Value dest, AffineMap packedOutIndexingMap,
218
+ const PackInfo &packInfo) {
219
+ Location loc = genericOp.getLoc ();
220
+ SmallVector<Value> inputOperands;
221
+ SmallVector<AffineMap> indexingMaps;
222
+ for (OpOperand *inputOperand : genericOp.getDpsInputOperands ()) {
223
+ auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand (
224
+ rewriter, loc, packInfo, genericOp, inputOperand);
225
+ inputOperands.push_back (packedOperand);
226
+ indexingMaps.push_back (packedIndexingMap);
227
+ }
228
+
229
+ int64_t numInnerLoops = packInfo.getNumTiledLoops ();
230
+ SmallVector<utils::IteratorType> iterTypes =
231
+ genericOp.getIteratorTypesArray ();
232
+ iterTypes.append (numInnerLoops, utils::IteratorType::parallel);
233
+
234
+ indexingMaps.push_back (packedOutIndexingMap);
235
+
236
+ auto newGenericOp = rewriter.create <linalg::GenericOp>(
237
+ loc, dest.getType (), inputOperands, dest, indexingMaps, iterTypes,
238
+ /* bodyBuild=*/ nullptr , linalg::getPrunedAttributeList (genericOp));
239
+ rewriter.cloneRegionBefore (genericOp.getRegion (), newGenericOp.getRegion (),
240
+ newGenericOp.getRegion ().begin ());
241
+ return newGenericOp;
242
+ }
243
+
212
244
// / Bubbles up tensor.pack op through elementwise generic op. This
213
245
// / swap pack(generic) to generic(pack). The new generic op works on packed
214
246
// / domain; pack ops are created for input and output operands. E.g.,
@@ -275,29 +307,13 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
275
307
return failure ();
276
308
277
309
OpOperand *opOperand = genericOp.getDpsInitOperand (0 );
278
- auto packInfo = getPackingInfoFromConsumer (
310
+ auto packInfo = getPackingInfoFromOperand (
279
311
genericOp.getMatchingIndexingMap (opOperand), packOp);
280
312
281
- Location loc = packOp.getLoc ();
282
- SmallVector<Value> inputOperands;
283
- SmallVector<AffineMap> indexingMaps;
284
- for (OpOperand *inputOperand : genericOp.getDpsInputOperands ()) {
285
- auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand (
286
- rewriter, loc, packInfo, genericOp, inputOperand);
287
- inputOperands.push_back (packedOperand);
288
- indexingMaps.push_back (packedIndexingMap);
289
- }
290
-
291
- int64_t numInnerLoops = packInfo.getNumTiledLoops ();
292
- SmallVector<utils::IteratorType> iterTypes =
293
- genericOp.getIteratorTypesArray ();
294
- iterTypes.append (numInnerLoops, utils::IteratorType::parallel);
295
-
296
313
// Rebuild the indexing map for the corresponding init operand.
297
314
auto [packedOutOperand, packedOutIndexingMap] =
298
- getOrCreatePackedViewOfOperand (rewriter, loc, packInfo, genericOp,
299
- opOperand);
300
- indexingMaps.push_back (packedOutIndexingMap);
315
+ getOrCreatePackedViewOfOperand (rewriter, genericOp.getLoc (), packInfo,
316
+ genericOp, opOperand);
301
317
302
318
// We'll replace the init operand with the destination of pack op if the init
303
319
// operand has not users in the body of the linalg.generic (pure elementwise).
@@ -306,15 +322,12 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
306
322
Value dest = (genericOp.getRegionOutputArgs ()[0 ].use_empty ())
307
323
? packOp.getDest ()
308
324
: packedOutOperand;
309
- auto newGenericOp = rewriter.create <linalg::GenericOp>(
310
- loc, dest.getType (), inputOperands, dest, indexingMaps, iterTypes,
311
- /* bodyBuild=*/ nullptr , linalg::getPrunedAttributeList (genericOp));
312
- rewriter.cloneRegionBefore (genericOp.getRegion (), newGenericOp.getRegion (),
313
- newGenericOp.getRegion ().begin ());
314
- return newGenericOp;
325
+
326
+ return packElementWiseOp (rewriter, genericOp, dest, packedOutIndexingMap,
327
+ packInfo);
315
328
}
316
329
317
- // Wrapper pattern that applies bubbleUpPackOpThroughElemGenericOp method.
330
+ // / Wrapper pattern that applies bubbleUpPackOpThroughElemGenericOp method.
318
331
struct BubbleUpPackOpThroughElemGenericOpPattern
319
332
: public OpRewritePattern<tensor::PackOp> {
320
333
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
@@ -328,10 +341,134 @@ struct BubbleUpPackOpThroughElemGenericOpPattern
328
341
return success ();
329
342
}
330
343
};
344
+
345
+ // TODO: Relax this restriction. We should unpack an elementwise also
346
+ // in the presence of multiple unpack ops as producers.
347
+ // / Return the unpacked operand, if present, for the current generic op.
348
+ static FailureOr<OpOperand *> getUnPackedOperand (GenericOp genericOp) {
349
+ OpOperand *unPackedOperand = nullptr ;
350
+ for (OpOperand &operand : genericOp->getOpOperands ()) {
351
+ auto unPackOp = operand.get ().getDefiningOp <tensor::UnPackOp>();
352
+ if (!unPackOp)
353
+ continue ;
354
+ if (unPackedOperand)
355
+ return failure ();
356
+ unPackedOperand = &operand;
357
+ }
358
+ if (!unPackedOperand)
359
+ return failure ();
360
+ return unPackedOperand;
361
+ }
362
+
363
+ // / Push down a tensor.unpack op through elementwise generic op.
364
+ // / The new generic op works on packed domain; pack ops are created for input
365
+ // / and output operands. A tensor.unpack op is inserted right after the packed
366
+ // / generic. E.g.
367
+ // /
368
+ // / #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
369
+ // /
370
+ // / %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
371
+ // /
372
+ // / %0 = tensor.empty() : tensor<12x56x56x64xf32>
373
+ // / %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
374
+ // / inner_dims_pos = [3] inner_tiles = [32] into %0
375
+ // / %2 = linalg.generic {indexing_maps = [#map],
376
+ // / iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
377
+ // / outs(%1 : tensor<12x56x56x64xf32>) {
378
+ // / ^bb0(%out : f32):
379
+ // / linalg.yield %out : f32
380
+ // / } -> tensor<12x56x56x64xf32>
381
+ // /
382
+ // / will be converted to
383
+ // /
384
+ // / #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
385
+ // /
386
+ // / %0 = tensor.empty() : tensor<12x56x56x64xf32>
387
+ // / %1 = linalg.generic {indexing_maps = [#map],
388
+ // / iterator_types = ["parallel", "parallel", "parallel",
389
+ // / "parallel", "parallel"]}
390
+ // / outs(%arg0 : tensor<12x2x56x56x32xf32>) {
391
+ // / ^bb0(%out : f32):
392
+ // / linalg.yield %out : f32
393
+ // / } -> tensor<12x2x56x56x32xf32>
394
+ // / %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
395
+ // / inner_dims_pos = [3] inner_tiles = [32] into %0
396
+ // /
397
+ static FailureOr<std::tuple<GenericOp, Value>>
398
+ pushDownUnPackOpThroughElemGenericOp (RewriterBase &rewriter,
399
+ GenericOp genericOp) {
400
+ if (!isElementwise (genericOp))
401
+ return failure ();
402
+ if (genericOp.getNumResults () != 1 )
403
+ return failure ();
404
+
405
+ // Collect the unPacked operand, if present.
406
+ auto maybeUnPackedOperand = getUnPackedOperand (genericOp);
407
+ if (failed (maybeUnPackedOperand))
408
+ return failure ();
409
+ OpOperand *unPackedOperand = *(maybeUnPackedOperand);
410
+
411
+ // Extract packing information.
412
+ tensor::UnPackOp producerUnPackOp =
413
+ unPackedOperand->get ().getDefiningOp <tensor::UnPackOp>();
414
+ assert (producerUnPackOp && " expect a valid UnPackOp" );
415
+ auto packInfo = getPackingInfoFromOperand (
416
+ genericOp.getMatchingIndexingMap (unPackedOperand), producerUnPackOp);
417
+
418
+ // Rebuild the indexing map for the corresponding init operand.
419
+ auto [packedOutOperand, packedOutIndexingMap] =
420
+ getOrCreatePackedViewOfOperand (rewriter, genericOp.getLoc (), packInfo,
421
+ genericOp, genericOp.getDpsInitOperand (0 ));
422
+
423
+ // If the dps init operand of the generic is a tensor.empty, do not pack it
424
+ // and forward the new tensor.empty as a destination.
425
+ Value dest = packedOutOperand;
426
+ if (auto initTensor = genericOp.getDpsInitOperand (0 )
427
+ ->get ()
428
+ .getDefiningOp <tensor::EmptyOp>()) {
429
+ if (auto packOp = packedOutOperand.getDefiningOp <tensor::PackOp>())
430
+ dest = packOp.getDest ();
431
+ }
432
+
433
+ // Pack the genericOp.
434
+ GenericOp newGenericOp = packElementWiseOp (rewriter, genericOp, dest,
435
+ packedOutIndexingMap, packInfo);
436
+
437
+ auto unPackOp = unPackedOperand->get ().getDefiningOp <tensor::UnPackOp>();
438
+ // Insert an unPackOp right after the packed generic.
439
+ Value unPackOpRes =
440
+ rewriter
441
+ .create <tensor::UnPackOp>(
442
+ genericOp.getLoc (),
443
+ newGenericOp.getTiedOpResult (newGenericOp.getDpsInitOperand (0 )),
444
+ unPackOp.getDest (), producerUnPackOp.getInnerDimsPos (),
445
+ producerUnPackOp.getMixedTiles (),
446
+ producerUnPackOp.getOuterDimsPerm ())
447
+ .getResult ();
448
+
449
+ return std::make_tuple (newGenericOp, unPackOpRes);
450
+ }
451
+
452
+ // Wrapper pattern that applies pushDownUnPackOpThroughElemGenericOp method.
453
+ struct PushDownUnPackOpThroughElemGenericOp
454
+ : public OpRewritePattern<GenericOp> {
455
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
456
+
457
+ LogicalResult matchAndRewrite (GenericOp genericOp,
458
+ PatternRewriter &rewriter) const override {
459
+ auto genericAndRepl =
460
+ pushDownUnPackOpThroughElemGenericOp (rewriter, genericOp);
461
+ if (failed (genericAndRepl))
462
+ return failure ();
463
+ rewriter.replaceOp (genericOp, std::get<1 >(*genericAndRepl));
464
+ return success ();
465
+ }
466
+ };
467
+
331
468
} // namespace
332
469
333
470
void mlir::linalg::populateDataLayoutPropagationPatterns (
334
471
RewritePatternSet &patterns) {
335
- patterns.insert <BubbleUpPackOpThroughElemGenericOpPattern>(
336
- patterns.getContext ());
472
+ patterns.insert <BubbleUpPackOpThroughElemGenericOpPattern,
473
+ PushDownUnPackOpThroughElemGenericOp>( patterns.getContext ());
337
474
}
0 commit comments