18
18
#include " mlir/IR/BlockAndValueMapping.h"
19
19
#include " mlir/IR/PatternMatch.h"
20
20
#include " mlir/Support/MathExtras.h"
21
+ #include " mlir/Transforms/RegionUtils.h"
21
22
#include " llvm/ADT/MapVector.h"
22
23
23
24
using namespace mlir ;
@@ -114,15 +115,28 @@ bool LoopPipelinerInternal::initializeLoopInfo(
114
115
return false ;
115
116
116
117
// All operations need to have a stage.
117
- if (forOp
118
- .walk ([this ](Operation *op) {
119
- if (op != forOp.getOperation () && !isa<scf::YieldOp>(op) &&
120
- stages.find (op) == stages.end ())
121
- return WalkResult::interrupt ();
122
- return WalkResult::advance ();
123
- })
124
- .wasInterrupted ())
125
- return false ;
118
+ for (Operation &op : forOp.getBody ()->without_terminator ()) {
119
+ if (stages.find (&op) == stages.end ()) {
120
+ op.emitOpError (" not assigned a pipeline stage" );
121
+ return false ;
122
+ }
123
+ }
124
+
125
+ // Currently, we do not support assigning stages to ops in nested regions. The
126
+ // block of all operations assigned a stage should be the single `scf.for`
127
+ // body block.
128
+ for (const auto &[op, stageNum] : stages) {
129
+ (void )stageNum;
130
+ if (op == forOp.getBody ()->getTerminator ()) {
131
+ op->emitError (" terminator should not be assigned a stage" );
132
+ return false ;
133
+ }
134
+ if (op->getBlock () != forOp.getBody ()) {
135
+ op->emitOpError (" the owning Block of all operations assigned a stage "
136
+ " should be the loop body block" );
137
+ return false ;
138
+ }
139
+ }
126
140
127
141
// Only support loop carried dependency with a distance of 1. This means the
128
142
// source of all the scf.yield operands needs to be defined by operations in
@@ -137,6 +151,27 @@ bool LoopPipelinerInternal::initializeLoopInfo(
137
151
return true ;
138
152
}
139
153
154
+ // / Clone `op` and call `callback` on the cloned op's oeprands as well as any
155
+ // / operands of nested ops that:
156
+ // / 1) aren't defined within the new op or
157
+ // / 2) are block arguments.
158
+ static Operation *
159
+ cloneAndUpdateOperands (RewriterBase &rewriter, Operation *op,
160
+ function_ref<void (OpOperand *newOperand)> callback) {
161
+ Operation *clone = rewriter.clone (*op);
162
+ for (OpOperand &operand : clone->getOpOperands ())
163
+ callback (&operand);
164
+ clone->walk ([&](Operation *nested) {
165
+ for (OpOperand &operand : nested->getOpOperands ()) {
166
+ Operation *def = operand.get ().getDefiningOp ();
167
+ if ((def && !clone->isAncestor (def)) ||
168
+ operand.get ().isa <BlockArgument>())
169
+ callback (&operand);
170
+ }
171
+ });
172
+ return clone;
173
+ }
174
+
140
175
void LoopPipelinerInternal::emitPrologue (PatternRewriter &rewriter) {
141
176
// Initialize the iteration argument to the loop initiale values.
142
177
for (BlockArgument &arg : forOp.getRegionIterArgs ()) {
@@ -152,12 +187,14 @@ void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
152
187
for (Operation *op : opOrder) {
153
188
if (stages[op] > i)
154
189
continue ;
155
- Operation *newOp = rewriter.clone (*op);
156
- for (unsigned opIdx = 0 ; opIdx < op->getNumOperands (); opIdx++) {
157
- auto it = valueMapping.find (op->getOperand (opIdx));
158
- if (it != valueMapping.end ())
159
- newOp->setOperand (opIdx, it->second [i - stages[op]]);
160
- }
190
+ Operation *newOp =
191
+ cloneAndUpdateOperands (rewriter, op, [&](OpOperand *newOperand) {
192
+ auto it = valueMapping.find (newOperand->get ());
193
+ if (it != valueMapping.end ()) {
194
+ Value replacement = it->second [i - stages[op]];
195
+ newOperand->set (replacement);
196
+ }
197
+ });
161
198
if (annotateFn)
162
199
annotateFn (newOp, PipeliningOption::PipelinerPart::Prologue, i);
163
200
for (unsigned destId : llvm::seq (unsigned (0 ), op->getNumResults ())) {
@@ -181,18 +218,25 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
181
218
llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
182
219
for (Operation *op : opOrder) {
183
220
unsigned stage = stages[op];
184
- for (OpOperand &operand : op->getOpOperands ()) {
221
+
222
+ auto analyzeOperand = [&](OpOperand &operand) {
185
223
Operation *def = operand.get ().getDefiningOp ();
186
224
if (!def)
187
- continue ;
225
+ return ;
188
226
auto defStage = stages.find (def);
189
227
if (defStage == stages.end () || defStage->second == stage)
190
- continue ;
228
+ return ;
191
229
assert (stage > defStage->second );
192
230
LiverangeInfo &info = crossStageValues[operand.get ()];
193
231
info.defStage = defStage->second ;
194
232
info.lastUseStage = std::max (info.lastUseStage , stage);
195
- }
233
+ };
234
+
235
+ for (OpOperand &operand : op->getOpOperands ())
236
+ analyzeOperand (operand);
237
+ visitUsedValuesDefinedAbove (op->getRegions (), [&](OpOperand *operand) {
238
+ analyzeOperand (*operand);
239
+ });
196
240
}
197
241
return crossStageValues;
198
242
}
@@ -243,9 +287,89 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
243
287
auto newForOp =
244
288
rewriter.create <scf::ForOp>(forOp.getLoc (), forOp.getLowerBound (), newUb,
245
289
forOp.getStep (), newLoopArg);
290
+ // When there are no iter args, the loop body terminator will be created.
291
+ // Since we always create it below, remove the terminator if it was created.
292
+ if (!newForOp.getBody ()->empty ())
293
+ rewriter.eraseOp (newForOp.getBody ()->getTerminator ());
246
294
return newForOp;
247
295
}
248
296
297
+ // / Replace any use of `target` with `replacement` in `op`'s operands or within
298
+ // / `op`'s nested regions.
299
+ static void replaceInOp (Operation *op, Value target, Value replacement) {
300
+ for (auto &use : llvm::make_early_inc_range (target.getUses ())) {
301
+ if (op->isAncestor (use.getOwner ()))
302
+ use.set (replacement);
303
+ }
304
+ }
305
+
306
+ // / Given a cloned op in the new kernel body, updates induction variable uses.
307
+ // / We replace it with a version incremented based on the stage where it is
308
+ // / used.
309
+ static void updateInductionVariableUses (RewriterBase &rewriter, Location loc,
310
+ Operation *newOp, Value newForIv,
311
+ unsigned maxStage, unsigned useStage,
312
+ unsigned step) {
313
+ rewriter.setInsertionPoint (newOp);
314
+ Value offset = rewriter.create <arith::ConstantIndexOp>(
315
+ loc, (maxStage - useStage) * step);
316
+ Value iv = rewriter.create <arith::AddIOp>(loc, newForIv, offset);
317
+ replaceInOp (newOp, newForIv, iv);
318
+ rewriter.setInsertionPointAfter (newOp);
319
+ }
320
+
321
+ // / If the value is a loop carried value coming from stage N + 1 remap, it will
322
+ // / become a direct use.
323
+ static void updateIterArgUses (RewriterBase &rewriter, BlockAndValueMapping &bvm,
324
+ Operation *newOp, ForOp oldForOp, ForOp newForOp,
325
+ unsigned useStage,
326
+ const DenseMap<Operation *, unsigned > &stages) {
327
+
328
+ for (unsigned i = 0 ; i < oldForOp.getNumRegionIterArgs (); i++) {
329
+ Value yieldedVal = oldForOp.getBody ()->getTerminator ()->getOperand (i);
330
+ Operation *dep = yieldedVal.getDefiningOp ();
331
+ if (!dep)
332
+ continue ;
333
+ auto stageDep = stages.find (dep);
334
+ if (stageDep == stages.end () || stageDep->second == useStage)
335
+ continue ;
336
+ if (stageDep->second != useStage + 1 )
337
+ continue ;
338
+ Value replacement = bvm.lookup (yieldedVal);
339
+ replaceInOp (newOp, newForOp.getRegionIterArg (i), replacement);
340
+ }
341
+ }
342
+
343
+ // / For operands defined in a previous stage we need to remap it to use the
344
+ // / correct region argument. We look for the right version of the Value based
345
+ // / on the stage where it is used.
346
+ static void updateCrossStageUses (
347
+ RewriterBase &rewriter, Operation *newOp, BlockAndValueMapping &bvm,
348
+ ForOp newForOp, unsigned useStage,
349
+ const DenseMap<Operation *, unsigned > &stages,
350
+ const llvm::DenseMap<std::pair<Value, unsigned >, unsigned > &loopArgMap) {
351
+ // Because we automatically cloned the sub-regions, there's no simple way
352
+ // to walk the nested regions in pairs of (oldOps, newOps), so we just
353
+ // traverse the set of remapped loop arguments, filter which ones are
354
+ // relevant, and replace any uses.
355
+ for (auto [remapPair, newIterIdx] : loopArgMap) {
356
+ auto [crossArgValue, stageIdx] = remapPair;
357
+ Operation *def = crossArgValue.getDefiningOp ();
358
+ assert (def);
359
+ unsigned stageDef = stages.lookup (def);
360
+ if (useStage <= stageDef || useStage - stageDef != stageIdx)
361
+ continue ;
362
+
363
+ // Use "lookupOrDefault" for the target value because some operations
364
+ // are remapped, while in other cases the original will be present.
365
+ Value target = bvm.lookupOrDefault (crossArgValue);
366
+ Value replacement = newForOp.getRegionIterArg (newIterIdx);
367
+
368
+ // Replace uses in the new op's operands and any nested uses.
369
+ replaceInOp (newOp, target, replacement);
370
+ }
371
+ }
372
+
249
373
void LoopPipelinerInternal::createKernel (
250
374
scf::ForOp newForOp,
251
375
const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
@@ -277,51 +401,17 @@ void LoopPipelinerInternal::createKernel(
277
401
for (Operation *op : opOrder) {
278
402
int64_t useStage = stages[op];
279
403
auto *newOp = rewriter.clone (*op, mapping);
280
- for (OpOperand &operand : op->getOpOperands ()) {
281
- // Special case for the induction variable uses. We replace it with a
282
- // version incremented based on the stage where it is used.
283
- if (operand.get () == forOp.getInductionVar ()) {
284
- rewriter.setInsertionPoint (newOp);
285
- Value offset = rewriter.create <arith::ConstantIndexOp>(
286
- forOp.getLoc (), (maxStage - stages[op]) * step);
287
- Value iv = rewriter.create <arith::AddIOp>(
288
- forOp.getLoc (), newForOp.getInductionVar (), offset);
289
- newOp->setOperand (operand.getOperandNumber (), iv);
290
- rewriter.setInsertionPointAfter (newOp);
291
- continue ;
292
- }
293
- auto arg = operand.get ().dyn_cast <BlockArgument>();
294
- if (arg && arg.getOwner () == forOp.getBody ()) {
295
- // If the value is a loop carried value coming from stage N + 1 remap,
296
- // it will become a direct use.
297
- Value ret = forOp.getBody ()->getTerminator ()->getOperand (
298
- arg.getArgNumber () - 1 );
299
- Operation *dep = ret.getDefiningOp ();
300
- if (!dep)
301
- continue ;
302
- auto stageDep = stages.find (dep);
303
- if (stageDep == stages.end () || stageDep->second == useStage)
304
- continue ;
305
- assert (stageDep->second == useStage + 1 );
306
- newOp->setOperand (operand.getOperandNumber (),
307
- mapping.lookupOrDefault (ret));
308
- continue ;
309
- }
310
- // For operands defined in a previous stage we need to remap it to use
311
- // the correct region argument. We look for the right version of the
312
- // Value based on the stage where it is used.
313
- Operation *def = operand.get ().getDefiningOp ();
314
- if (!def)
315
- continue ;
316
- auto stageDef = stages.find (def);
317
- if (stageDef == stages.end () || stageDef->second == useStage)
318
- continue ;
319
- auto remap = loopArgMap.find (
320
- std::make_pair (operand.get (), useStage - stageDef->second ));
321
- assert (remap != loopArgMap.end ());
322
- newOp->setOperand (operand.getOperandNumber (),
323
- newForOp.getRegionIterArgs ()[remap->second ]);
324
- }
404
+
405
+ // Within the kernel body, update uses of the induction variable, uses of
406
+ // the original iter args, and uses of cross stage values.
407
+ updateInductionVariableUses (rewriter, forOp.getLoc (), newOp,
408
+ newForOp.getInductionVar (), maxStage,
409
+ stages[op], step);
410
+ updateIterArgUses (rewriter, mapping, newOp, forOp, newForOp, useStage,
411
+ stages);
412
+ updateCrossStageUses (rewriter, newOp, mapping, newForOp, useStage, stages,
413
+ loopArgMap);
414
+
325
415
if (predicates[useStage]) {
326
416
newOp = predicateFn (newOp, predicates[useStage], rewriter);
327
417
// Remap the results to the new predicated one.
@@ -382,21 +472,20 @@ LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) {
382
472
forOp.getLoc (), lb + step * ((((ub - 1 ) - lb) / step) - i));
383
473
setValueMapping (forOp.getInductionVar (), newlastIter, maxStage - i);
384
474
}
385
- // Emit `maxStage - 1` epilogue part that includes operations fro stages
475
+ // Emit `maxStage - 1` epilogue part that includes operations from stages
386
476
// [i; maxStage].
387
477
for (int64_t i = 1 ; i <= maxStage; i++) {
388
478
for (Operation *op : opOrder) {
389
479
if (stages[op] < i)
390
480
continue ;
391
- Operation *newOp = rewriter.clone (*op);
392
- for (unsigned opIdx = 0 ; opIdx < op->getNumOperands (); opIdx++) {
393
- auto it = valueMapping.find (op->getOperand (opIdx));
394
- if (it != valueMapping.end ()) {
395
- Value v = it->second [maxStage - stages[op] + i];
396
- assert (v);
397
- newOp->setOperand (opIdx, v);
398
- }
399
- }
481
+ Operation *newOp =
482
+ cloneAndUpdateOperands (rewriter, op, [&](OpOperand *newOperand) {
483
+ auto it = valueMapping.find (newOperand->get ());
484
+ if (it != valueMapping.end ()) {
485
+ Value replacement = it->second [maxStage - stages[op] + i];
486
+ newOperand->set (replacement);
487
+ }
488
+ });
400
489
if (annotateFn)
401
490
annotateFn (newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1 );
402
491
for (unsigned destId : llvm::seq (unsigned (0 ), op->getNumResults ())) {
0 commit comments