Skip to content

Commit 42ada5f

Browse files
committed
[MLIR] NFC cleanup/modernize memref-dataflow-opt / getNestingDepth
Bring code up to date with recent changes to the core infrastructure / coding style. Differential Revision: https://reviews.llvm.org/D77998
1 parent 500e038 commit 42ada5f

File tree

6 files changed

+32
-43
lines changed

6 files changed

+32
-43
lines changed

mlir/include/mlir/Analysis/Utils.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops);
4141

4242
/// Returns the nesting depth of this operation, i.e., the number of loops
4343
/// surrounding this operation.
44-
unsigned getNestingDepth(Operation &op);
44+
unsigned getNestingDepth(Operation *op);
4545

4646
/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
4747
/// at 'forOp'.

mlir/lib/Analysis/Utils.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -569,8 +569,8 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
569569
if (srcAccess.memref != dstAccess.memref)
570570
continue;
571571
// Check if 'loopDepth' exceeds nesting depth of src/dst ops.
572-
if ((!isBackwardSlice && loopDepth > getNestingDepth(*opsA[i])) ||
573-
(isBackwardSlice && loopDepth > getNestingDepth(*opsB[j]))) {
572+
if ((!isBackwardSlice && loopDepth > getNestingDepth(opsA[i])) ||
573+
(isBackwardSlice && loopDepth > getNestingDepth(opsB[j]))) {
574574
LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n.");
575575
return failure();
576576
}
@@ -895,8 +895,8 @@ bool MemRefAccess::isStore() const { return isa<AffineStoreOp>(opInst); }
895895

896896
/// Returns the nesting depth of this statement, i.e., the number of loops
897897
/// surrounding this statement.
898-
unsigned mlir::getNestingDepth(Operation &op) {
899-
Operation *currOp = &op;
898+
unsigned mlir::getNestingDepth(Operation *op) {
899+
Operation *currOp = op;
900900
unsigned depth = 0;
901901
while ((currOp = currOp->getParentOp())) {
902902
if (isa<AffineForOp>(currOp))
@@ -957,7 +957,7 @@ static Optional<int64_t> getMemoryFootprintBytes(Block &block,
957957
auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
958958
if (failed(
959959
region->compute(opInst,
960-
/*loopDepth=*/getNestingDepth(*block.begin())))) {
960+
/*loopDepth=*/getNestingDepth(&*block.begin())))) {
961961
return opInst->emitError("error obtaining memory region\n");
962962
}
963963

@@ -1023,7 +1023,7 @@ bool mlir::isLoopParallel(AffineForOp forOp) {
10231023
return false;
10241024

10251025
// Dep check depth would be number of enclosing loops + 1.
1026-
unsigned depth = getNestingDepth(*forOp.getOperation()) + 1;
1026+
unsigned depth = getNestingDepth(forOp) + 1;
10271027

10281028
// Check dependences between all pairs of ops in 'loadAndStoreOpInsts'.
10291029
for (auto *srcOpInst : loadAndStoreOpInsts) {

mlir/lib/Transforms/LoopFusion.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1492,7 +1492,7 @@ struct GreedyFusion {
14921492
srcStoreOp = nullptr;
14931493
break;
14941494
}
1495-
unsigned loopDepth = getNestingDepth(*storeOp);
1495+
unsigned loopDepth = getNestingDepth(storeOp);
14961496
if (loopDepth > maxLoopDepth) {
14971497
maxLoopDepth = loopDepth;
14981498
srcStoreOp = storeOp;

mlir/lib/Transforms/MemRefDataFlowOpt.cpp

+22-26
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,18 @@ std::unique_ptr<OperationPass<FuncOp>> mlir::createMemRefDataFlowOptPass() {
8585
// This is a straightforward implementation not optimized for speed. Optimize
8686
// if needed.
8787
void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
88-
Operation *loadOpInst = loadOp.getOperation();
89-
90-
// First pass over the use list to get minimum number of surrounding
88+
// First pass over the use list to get the minimum number of surrounding
9189
// loops common between the load op and the store op, with min taken across
9290
// all store ops.
9391
SmallVector<Operation *, 8> storeOps;
94-
unsigned minSurroundingLoops = getNestingDepth(*loadOpInst);
92+
unsigned minSurroundingLoops = getNestingDepth(loadOp);
9593
for (auto *user : loadOp.getMemRef().getUsers()) {
9694
auto storeOp = dyn_cast<AffineStoreOp>(user);
9795
if (!storeOp)
9896
continue;
99-
auto *storeOpInst = storeOp.getOperation();
100-
unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst);
97+
unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
10198
minSurroundingLoops = std::min(nsLoops, minSurroundingLoops);
102-
storeOps.push_back(storeOpInst);
99+
storeOps.push_back(storeOp);
103100
}
104101

105102
// The list of store op candidates for forwarding that satisfy conditions
@@ -111,12 +108,12 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
111108
// post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores.
112109
SmallVector<Operation *, 8> depSrcStores;
113110

114-
for (auto *storeOpInst : storeOps) {
115-
MemRefAccess srcAccess(storeOpInst);
116-
MemRefAccess destAccess(loadOpInst);
111+
for (auto *storeOp : storeOps) {
112+
MemRefAccess srcAccess(storeOp);
113+
MemRefAccess destAccess(loadOp);
117114
// Find stores that may be reaching the load.
118115
FlatAffineConstraints dependenceConstraints;
119-
unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst);
116+
unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
120117
unsigned d;
121118
// Dependences at loop depth <= minSurroundingLoops do NOT matter.
122119
for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
@@ -130,7 +127,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
130127
continue;
131128

132129
// Stores that *may* be reaching the load.
133-
depSrcStores.push_back(storeOpInst);
130+
depSrcStores.push_back(storeOp);
134131

135132
// 1. Check if the store and the load have mathematically equivalent
136133
// affine access functions; this implies that they statically refer to the
@@ -144,11 +141,11 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
144141
continue;
145142

146143
// 2. The store has to dominate the load op to be candidate.
147-
if (!domInfo->dominates(storeOpInst, loadOpInst))
144+
if (!domInfo->dominates(storeOp, loadOp))
148145
continue;
149146

150147
// We now have a candidate for forwarding.
151-
fwdingCandidates.push_back(storeOpInst);
148+
fwdingCandidates.push_back(storeOp);
152149
}
153150

154151
// 3. Of all the store op's that meet the above criteria, the store that
@@ -158,11 +155,11 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
158155
// Note: this can be implemented in a cleaner way with postdominator tree
159156
// traversals. Consider this for the future if needed.
160157
Operation *lastWriteStoreOp = nullptr;
161-
for (auto *storeOpInst : fwdingCandidates) {
158+
for (auto *storeOp : fwdingCandidates) {
162159
if (llvm::all_of(depSrcStores, [&](Operation *depStore) {
163-
return postDomInfo->postDominates(storeOpInst, depStore);
160+
return postDomInfo->postDominates(storeOp, depStore);
164161
})) {
165-
lastWriteStoreOp = storeOpInst;
162+
lastWriteStoreOp = storeOp;
166163
break;
167164
}
168165
}
@@ -175,7 +172,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
175172
// Record the memref for a later sweep to optimize away.
176173
memrefsToErase.insert(loadOp.getMemRef());
177174
// Record this to erase later.
178-
loadOpsToErase.push_back(loadOpInst);
175+
loadOpsToErase.push_back(loadOp);
179176
}
180177

181178
void MemRefDataFlowOpt::runOnFunction() {
@@ -192,32 +189,31 @@ void MemRefDataFlowOpt::runOnFunction() {
192189
loadOpsToErase.clear();
193190
memrefsToErase.clear();
194191

195-
// Walk all load's and perform load/store forwarding.
192+
// Walk all load's and perform store to load forwarding.
196193
f.walk([&](AffineLoadOp loadOp) { forwardStoreToLoad(loadOp); });
197194

198195
// Erase all load op's whose results were replaced with store fwd'ed ones.
199-
for (auto *loadOp : loadOpsToErase) {
196+
for (auto *loadOp : loadOpsToErase)
200197
loadOp->erase();
201-
}
202198

203199
// Check if the store fwd'ed memrefs are now left with only stores and can
204200
// thus be completely deleted. Note: the canonicalize pass should be able
205201
// to do this as well, but we'll do it here since we collected these anyway.
206202
for (auto memref : memrefsToErase) {
207203
// If the memref hasn't been alloc'ed in this function, skip.
208-
Operation *defInst = memref.getDefiningOp();
209-
if (!defInst || !isa<AllocOp>(defInst))
204+
Operation *defOp = memref.getDefiningOp();
205+
if (!defOp || !isa<AllocOp>(defOp))
210206
// TODO(mlir-team): if the memref was returned by a 'call' operation, we
211207
// could still erase it if the call had no side-effects.
212208
continue;
213-
if (llvm::any_of(memref.getUsers(), [&](Operation *ownerInst) {
214-
return (!isa<AffineStoreOp>(ownerInst) && !isa<DeallocOp>(ownerInst));
209+
if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
210+
return (!isa<AffineStoreOp>(ownerOp) && !isa<DeallocOp>(ownerOp));
215211
}))
216212
continue;
217213

218214
// Erase all stores, the dealloc, and the alloc on the memref.
219215
for (auto *user : llvm::make_early_inc_range(memref.getUsers()))
220216
user->erase();
221-
defInst->erase();
217+
defOp->erase();
222218
}
223219
}

mlir/lib/Transforms/Utils/LoopUtils.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1911,7 +1911,7 @@ uint64_t mlir::affineDataCopyGenerate(Block::iterator begin,
19111911

19121912
// Copies will be generated for this depth, i.e., symbolic in all loops
19131913
// surrounding this block range.
1914-
unsigned copyDepth = getNestingDepth(*begin);
1914+
unsigned copyDepth = getNestingDepth(&*begin);
19151915

19161916
LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth
19171917
<< "\n");

mlir/test/lib/Transforms/TestLoopFusion.cpp

+1-8
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,14 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13-
#include "mlir/Analysis/AffineAnalysis.h"
14-
#include "mlir/Analysis/AffineStructures.h"
1513
#include "mlir/Analysis/Utils.h"
1614
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1715
#include "mlir/Dialect/StandardOps/IR/Ops.h"
18-
#include "mlir/IR/Builders.h"
1916
#include "mlir/Pass/Pass.h"
2017
#include "mlir/Transforms/LoopFusionUtils.h"
2118
#include "mlir/Transforms/LoopUtils.h"
2219
#include "mlir/Transforms/Passes.h"
2320

24-
#include "llvm/ADT/STLExtras.h"
25-
#include "llvm/Support/CommandLine.h"
26-
#include "llvm/Support/Debug.h"
27-
2821
#define DEBUG_TYPE "test-loop-fusion"
2922

3023
using namespace mlir;
@@ -90,7 +83,7 @@ static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) {
9083
std::string result;
9184
llvm::raw_string_ostream os(result);
9285
// Slice insertion point format [loop-depth, operation-block-index]
93-
unsigned ipd = getNestingDepth(*sliceUnion.insertPoint);
86+
unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
9487
unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
9588
os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
9689
<< ")";

0 commit comments

Comments
 (0)