Skip to content

Commit f5fe92f

Browse files
[mlir][SCF] Fix loop pipelining unable to handle ops with regions
This change allows the SCF LoopPipelining transform to handle ops with nested regions within the pipelined `scf.for` body. The op and nested regions are treated as a single unit from the transform's perspective. This change also makes explicit the requirement that only ops whose parent Block is the loop body Block are allowed to be scheduled by the caller. Reviewed By: ThomasRaoux, nicolasvasilache Differential Revision: https://reviews.llvm.org/D133965
1 parent 07d0ef3 commit f5fe92f

File tree

5 files changed

+437
-82
lines changed

5 files changed

+437
-82
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

+6
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ def ForOp : SCF_Op<"for",
233233
Block::BlockArgListType getRegionIterArgs() {
234234
return getBody()->getArguments().drop_front(getNumInductionVars());
235235
}
236+
/// Return the `index`-th region iteration argument.
237+
BlockArgument getRegionIterArg(unsigned index) {
238+
assert(index < getNumRegionIterArgs() &&
239+
"expected an index less than the number of region iter args");
240+
return getBody()->getArguments().drop_front(getNumInductionVars())[index];
241+
}
236242
Operation::operand_range getIterOperands() {
237243
return getOperands().drop_front(getNumControlOperands());
238244
}

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

+163-74
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/IR/BlockAndValueMapping.h"
1919
#include "mlir/IR/PatternMatch.h"
2020
#include "mlir/Support/MathExtras.h"
21+
#include "mlir/Transforms/RegionUtils.h"
2122
#include "llvm/ADT/MapVector.h"
2223

2324
using namespace mlir;
@@ -114,15 +115,28 @@ bool LoopPipelinerInternal::initializeLoopInfo(
114115
return false;
115116

116117
// All operations need to have a stage.
117-
if (forOp
118-
.walk([this](Operation *op) {
119-
if (op != forOp.getOperation() && !isa<scf::YieldOp>(op) &&
120-
stages.find(op) == stages.end())
121-
return WalkResult::interrupt();
122-
return WalkResult::advance();
123-
})
124-
.wasInterrupted())
125-
return false;
118+
for (Operation &op : forOp.getBody()->without_terminator()) {
119+
if (stages.find(&op) == stages.end()) {
120+
op.emitOpError("not assigned a pipeline stage");
121+
return false;
122+
}
123+
}
124+
125+
// Currently, we do not support assigning stages to ops in nested regions. The
126+
// block of all operations assigned a stage should be the single `scf.for`
127+
// body block.
128+
for (const auto &[op, stageNum] : stages) {
129+
(void)stageNum;
130+
if (op == forOp.getBody()->getTerminator()) {
131+
op->emitError("terminator should not be assigned a stage");
132+
return false;
133+
}
134+
if (op->getBlock() != forOp.getBody()) {
135+
op->emitOpError("the owning Block of all operations assigned a stage "
136+
"should be the loop body block");
137+
return false;
138+
}
139+
}
126140

127141
// Only support loop carried dependency with a distance of 1. This means the
128142
// source of all the scf.yield operands needs to be defined by operations in
@@ -137,6 +151,27 @@ bool LoopPipelinerInternal::initializeLoopInfo(
137151
return true;
138152
}
139153

154+
/// Clone `op` and call `callback` on the cloned op's operands as well as any
155+
/// operands of nested ops that:
156+
/// 1) aren't defined within the new op or
157+
/// 2) are block arguments.
158+
static Operation *
159+
cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
160+
function_ref<void(OpOperand *newOperand)> callback) {
161+
Operation *clone = rewriter.clone(*op);
162+
for (OpOperand &operand : clone->getOpOperands())
163+
callback(&operand);
164+
clone->walk([&](Operation *nested) {
165+
for (OpOperand &operand : nested->getOpOperands()) {
166+
Operation *def = operand.get().getDefiningOp();
167+
if ((def && !clone->isAncestor(def)) ||
168+
operand.get().isa<BlockArgument>())
169+
callback(&operand);
170+
}
171+
});
172+
return clone;
173+
}
174+
140175
void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
141176
// Initialize the iteration argument to the loop initial values.
142177
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
@@ -152,12 +187,14 @@ void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
152187
for (Operation *op : opOrder) {
153188
if (stages[op] > i)
154189
continue;
155-
Operation *newOp = rewriter.clone(*op);
156-
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); opIdx++) {
157-
auto it = valueMapping.find(op->getOperand(opIdx));
158-
if (it != valueMapping.end())
159-
newOp->setOperand(opIdx, it->second[i - stages[op]]);
160-
}
190+
Operation *newOp =
191+
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
192+
auto it = valueMapping.find(newOperand->get());
193+
if (it != valueMapping.end()) {
194+
Value replacement = it->second[i - stages[op]];
195+
newOperand->set(replacement);
196+
}
197+
});
161198
if (annotateFn)
162199
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
163200
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -181,18 +218,25 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
181218
llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
182219
for (Operation *op : opOrder) {
183220
unsigned stage = stages[op];
184-
for (OpOperand &operand : op->getOpOperands()) {
221+
222+
auto analyzeOperand = [&](OpOperand &operand) {
185223
Operation *def = operand.get().getDefiningOp();
186224
if (!def)
187-
continue;
225+
return;
188226
auto defStage = stages.find(def);
189227
if (defStage == stages.end() || defStage->second == stage)
190-
continue;
228+
return;
191229
assert(stage > defStage->second);
192230
LiverangeInfo &info = crossStageValues[operand.get()];
193231
info.defStage = defStage->second;
194232
info.lastUseStage = std::max(info.lastUseStage, stage);
195-
}
233+
};
234+
235+
for (OpOperand &operand : op->getOpOperands())
236+
analyzeOperand(operand);
237+
visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
238+
analyzeOperand(*operand);
239+
});
196240
}
197241
return crossStageValues;
198242
}
@@ -243,9 +287,89 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
243287
auto newForOp =
244288
rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
245289
forOp.getStep(), newLoopArg);
290+
// When there are no iter args, the loop body terminator will be created.
291+
// Since we always create it below, remove the terminator if it was created.
292+
if (!newForOp.getBody()->empty())
293+
rewriter.eraseOp(newForOp.getBody()->getTerminator());
246294
return newForOp;
247295
}
248296

297+
/// Replace any use of `target` with `replacement` in `op`'s operands or within
298+
/// `op`'s nested regions.
299+
static void replaceInOp(Operation *op, Value target, Value replacement) {
300+
for (auto &use : llvm::make_early_inc_range(target.getUses())) {
301+
if (op->isAncestor(use.getOwner()))
302+
use.set(replacement);
303+
}
304+
}
305+
306+
/// Given a cloned op in the new kernel body, updates induction variable uses.
307+
/// We replace it with a version incremented based on the stage where it is
308+
/// used.
309+
static void updateInductionVariableUses(RewriterBase &rewriter, Location loc,
310+
Operation *newOp, Value newForIv,
311+
unsigned maxStage, unsigned useStage,
312+
unsigned step) {
313+
rewriter.setInsertionPoint(newOp);
314+
Value offset = rewriter.create<arith::ConstantIndexOp>(
315+
loc, (maxStage - useStage) * step);
316+
Value iv = rewriter.create<arith::AddIOp>(loc, newForIv, offset);
317+
replaceInOp(newOp, newForIv, iv);
318+
rewriter.setInsertionPointAfter(newOp);
319+
}
320+
321+
/// If the value is a loop-carried value coming from stage N + 1, remap it: it
322+
/// will become a direct use.
323+
static void updateIterArgUses(RewriterBase &rewriter, BlockAndValueMapping &bvm,
324+
Operation *newOp, ForOp oldForOp, ForOp newForOp,
325+
unsigned useStage,
326+
const DenseMap<Operation *, unsigned> &stages) {
327+
328+
for (unsigned i = 0; i < oldForOp.getNumRegionIterArgs(); i++) {
329+
Value yieldedVal = oldForOp.getBody()->getTerminator()->getOperand(i);
330+
Operation *dep = yieldedVal.getDefiningOp();
331+
if (!dep)
332+
continue;
333+
auto stageDep = stages.find(dep);
334+
if (stageDep == stages.end() || stageDep->second == useStage)
335+
continue;
336+
if (stageDep->second != useStage + 1)
337+
continue;
338+
Value replacement = bvm.lookup(yieldedVal);
339+
replaceInOp(newOp, newForOp.getRegionIterArg(i), replacement);
340+
}
341+
}
342+
343+
/// For operands defined in a previous stage we need to remap them to use the
344+
/// correct region argument. We look for the right version of the Value based
345+
/// on the stage where it is used.
346+
static void updateCrossStageUses(
347+
RewriterBase &rewriter, Operation *newOp, BlockAndValueMapping &bvm,
348+
ForOp newForOp, unsigned useStage,
349+
const DenseMap<Operation *, unsigned> &stages,
350+
const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
351+
// Because we automatically cloned the sub-regions, there's no simple way
352+
// to walk the nested regions in pairs of (oldOps, newOps), so we just
353+
// traverse the set of remapped loop arguments, filter which ones are
354+
// relevant, and replace any uses.
355+
for (auto [remapPair, newIterIdx] : loopArgMap) {
356+
auto [crossArgValue, stageIdx] = remapPair;
357+
Operation *def = crossArgValue.getDefiningOp();
358+
assert(def);
359+
unsigned stageDef = stages.lookup(def);
360+
if (useStage <= stageDef || useStage - stageDef != stageIdx)
361+
continue;
362+
363+
// Use "lookupOrDefault" for the target value because some operations
364+
// are remapped, while in other cases the original will be present.
365+
Value target = bvm.lookupOrDefault(crossArgValue);
366+
Value replacement = newForOp.getRegionIterArg(newIterIdx);
367+
368+
// Replace uses in the new op's operands and any nested uses.
369+
replaceInOp(newOp, target, replacement);
370+
}
371+
}
372+
249373
void LoopPipelinerInternal::createKernel(
250374
scf::ForOp newForOp,
251375
const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
@@ -277,51 +401,17 @@ void LoopPipelinerInternal::createKernel(
277401
for (Operation *op : opOrder) {
278402
int64_t useStage = stages[op];
279403
auto *newOp = rewriter.clone(*op, mapping);
280-
for (OpOperand &operand : op->getOpOperands()) {
281-
// Special case for the induction variable uses. We replace it with a
282-
// version incremented based on the stage where it is used.
283-
if (operand.get() == forOp.getInductionVar()) {
284-
rewriter.setInsertionPoint(newOp);
285-
Value offset = rewriter.create<arith::ConstantIndexOp>(
286-
forOp.getLoc(), (maxStage - stages[op]) * step);
287-
Value iv = rewriter.create<arith::AddIOp>(
288-
forOp.getLoc(), newForOp.getInductionVar(), offset);
289-
newOp->setOperand(operand.getOperandNumber(), iv);
290-
rewriter.setInsertionPointAfter(newOp);
291-
continue;
292-
}
293-
auto arg = operand.get().dyn_cast<BlockArgument>();
294-
if (arg && arg.getOwner() == forOp.getBody()) {
295-
// If the value is a loop carried value coming from stage N + 1 remap,
296-
// it will become a direct use.
297-
Value ret = forOp.getBody()->getTerminator()->getOperand(
298-
arg.getArgNumber() - 1);
299-
Operation *dep = ret.getDefiningOp();
300-
if (!dep)
301-
continue;
302-
auto stageDep = stages.find(dep);
303-
if (stageDep == stages.end() || stageDep->second == useStage)
304-
continue;
305-
assert(stageDep->second == useStage + 1);
306-
newOp->setOperand(operand.getOperandNumber(),
307-
mapping.lookupOrDefault(ret));
308-
continue;
309-
}
310-
// For operands defined in a previous stage we need to remap it to use
311-
// the correct region argument. We look for the right version of the
312-
// Value based on the stage where it is used.
313-
Operation *def = operand.get().getDefiningOp();
314-
if (!def)
315-
continue;
316-
auto stageDef = stages.find(def);
317-
if (stageDef == stages.end() || stageDef->second == useStage)
318-
continue;
319-
auto remap = loopArgMap.find(
320-
std::make_pair(operand.get(), useStage - stageDef->second));
321-
assert(remap != loopArgMap.end());
322-
newOp->setOperand(operand.getOperandNumber(),
323-
newForOp.getRegionIterArgs()[remap->second]);
324-
}
404+
405+
// Within the kernel body, update uses of the induction variable, uses of
406+
// the original iter args, and uses of cross stage values.
407+
updateInductionVariableUses(rewriter, forOp.getLoc(), newOp,
408+
newForOp.getInductionVar(), maxStage,
409+
stages[op], step);
410+
updateIterArgUses(rewriter, mapping, newOp, forOp, newForOp, useStage,
411+
stages);
412+
updateCrossStageUses(rewriter, newOp, mapping, newForOp, useStage, stages,
413+
loopArgMap);
414+
325415
if (predicates[useStage]) {
326416
newOp = predicateFn(newOp, predicates[useStage], rewriter);
327417
// Remap the results to the new predicated one.
@@ -382,21 +472,20 @@ LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) {
382472
forOp.getLoc(), lb + step * ((((ub - 1) - lb) / step) - i));
383473
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
384474
}
385-
// Emit `maxStage - 1` epilogue part that includes operations fro stages
475+
// Emit `maxStage - 1` epilogue part that includes operations from stages
386476
// [i; maxStage].
387477
for (int64_t i = 1; i <= maxStage; i++) {
388478
for (Operation *op : opOrder) {
389479
if (stages[op] < i)
390480
continue;
391-
Operation *newOp = rewriter.clone(*op);
392-
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); opIdx++) {
393-
auto it = valueMapping.find(op->getOperand(opIdx));
394-
if (it != valueMapping.end()) {
395-
Value v = it->second[maxStage - stages[op] + i];
396-
assert(v);
397-
newOp->setOperand(opIdx, v);
398-
}
399-
}
481+
Operation *newOp =
482+
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
483+
auto it = valueMapping.find(newOperand->get());
484+
if (it != valueMapping.end()) {
485+
Value replacement = it->second[maxStage - stages[op] + i];
486+
newOperand->set(replacement);
487+
}
488+
});
400489
if (annotateFn)
401490
annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
402491
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {

0 commit comments

Comments
 (0)