Skip to content

Commit 0bda492

Browse files
authored
[MLIR] Cache symbol tables during OneShotBufferization analyses (llvm#138125)
During bufferization, the callee of each `func::CallOp` / `CallableOpInterface` operation is retrieved by means of a symbol table that is temporarily built for the lookup purpose. The creation of the symbol table requires a linear scan of the operation body (e.g., a linear scan of the `ModuleOp` body). Considering that functions are typically called at least once, this leads to a scaling behavior that is quadratic with respect to the number of symbols. The problem is described in the following Discourse topic: https://discourse.llvm.org/t/quadratic-scaling-of-bufferization/86122/ This patch aims to partially address this scaling issue by leveraging the `SymbolTableCollection` class, whose instance is added to the `FuncAnalysisState` extension. Later modifications are also expected to address the problem in other methods required by `BufferizableOpInterface` (e.g., `bufferize` and `getBufferType`), which suffer of the same problem but do not provide access to any bufferization state.
1 parent cf16c97 commit 0bda492

File tree

3 files changed

+43
-10
lines changed

3 files changed

+43
-10
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
6969
/// analyzed.
7070
DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
7171

72+
/// A collection of cached SymbolTables used for faster function lookup.
73+
mutable SymbolTableCollection symbolTables;
74+
7275
/// This function is called right before analyzing the given FuncOp. It
7376
/// initializes the data structures for the FuncOp in this state object.
7477
void startFunctionAnalysis(FuncOp funcOp);

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,29 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
7676
}
7777

7878
/// Return the FuncOp called by `callOp`.
79-
static FuncOp getCalledFunction(CallOpInterface callOp) {
79+
static FuncOp getCalledFunction(CallOpInterface callOp,
80+
SymbolTableCollection &symbolTables) {
8081
SymbolRefAttr sym =
8182
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
8283
if (!sym)
8384
return nullptr;
8485
return dyn_cast_or_null<FuncOp>(
85-
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
86+
symbolTables.lookupNearestSymbolFrom(callOp, sym));
87+
}
88+
89+
/// Return the FuncOp called by `callOp`.
90+
static FuncOp getCalledFunction(CallOpInterface callOp,
91+
const AnalysisState &state) {
92+
auto &oneShotAnalysisState = static_cast<const OneShotAnalysisState &>(state);
93+
94+
if (auto *funcAnalysisState =
95+
oneShotAnalysisState.getExtension<FuncAnalysisState>()) {
96+
// Use the cached symbol tables.
97+
return getCalledFunction(callOp, funcAnalysisState->symbolTables);
98+
}
99+
100+
SymbolTableCollection symbolTables;
101+
return getCalledFunction(callOp, symbolTables);
86102
}
87103

88104
/// Get FuncAnalysisState.
@@ -135,7 +151,7 @@ struct CallOpInterface
135151
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
136152
const AnalysisState &state) const {
137153
func::CallOp callOp = cast<func::CallOp>(op);
138-
FuncOp funcOp = getCalledFunction(callOp);
154+
FuncOp funcOp = getCalledFunction(callOp, state);
139155
assert(funcOp && "expected CallOp to a FuncOp");
140156

141157
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
@@ -150,7 +166,7 @@ struct CallOpInterface
150166
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
151167
const AnalysisState &state) const {
152168
func::CallOp callOp = cast<func::CallOp>(op);
153-
FuncOp funcOp = getCalledFunction(callOp);
169+
FuncOp funcOp = getCalledFunction(callOp, state);
154170
assert(funcOp && "expected CallOp to a FuncOp");
155171

156172
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
@@ -165,7 +181,7 @@ struct CallOpInterface
165181
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
166182
const AnalysisState &state) const {
167183
func::CallOp callOp = cast<func::CallOp>(op);
168-
FuncOp funcOp = getCalledFunction(callOp);
184+
FuncOp funcOp = getCalledFunction(callOp, state);
169185
assert(funcOp && "expected CallOp to a FuncOp");
170186
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
171187
// FuncOp not analyzed yet. Any OpResult may be aliasing.
@@ -199,7 +215,11 @@ struct CallOpInterface
199215
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
200216
SmallVector<Value> &invocationStack) const {
201217
auto callOp = cast<func::CallOp>(op);
202-
FuncOp funcOp = getCalledFunction(callOp);
218+
219+
// TODO Avoid recomputing the symbol tables every time.
220+
SymbolTableCollection symbolTable;
221+
222+
FuncOp funcOp = getCalledFunction(callOp, symbolTable);
203223
assert(funcOp && "expected CallOp to a FuncOp");
204224

205225
// If the callee was already bufferized, we can directly take the type from
@@ -243,7 +263,11 @@ struct CallOpInterface
243263
// 2. Rewrite tensor operands as memrefs based on type of the already
244264
// bufferized callee.
245265
SmallVector<Value> newOperands;
246-
FuncOp funcOp = getCalledFunction(callOp);
266+
267+
// TODO Avoid recomputing the symbol tables every time.
268+
SymbolTableCollection symbolTable;
269+
270+
FuncOp funcOp = getCalledFunction(callOp, symbolTable);
247271
assert(funcOp && "expected CallOp to a FuncOp");
248272
FunctionType funcType = funcOp.getFunctionType();
249273

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,13 +280,15 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
280280
}
281281

282282
/// Return the func::FuncOp called by `callOp`.
283-
static func::FuncOp getCalledFunction(func::CallOp callOp) {
283+
static func::FuncOp
284+
getCalledFunction(func::CallOp callOp,
285+
mlir::SymbolTableCollection &symbolTable) {
284286
SymbolRefAttr sym =
285287
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
286288
if (!sym)
287289
return nullptr;
288290
return dyn_cast_or_null<func::FuncOp>(
289-
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
291+
symbolTable.lookupNearestSymbolFrom(callOp, sym));
290292
}
291293

292294
/// Return "true" if the given function signature has tensor semantics.
@@ -314,11 +316,15 @@ static LogicalResult getFuncOpsOrderedByCalls(
314316
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
315317
// For each FuncOp, the number of func::CallOp it contains.
316318
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
319+
320+
// TODO Avoid recomputing the symbol tables every time.
321+
mlir::SymbolTableCollection symbolTable;
322+
317323
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
318324
// Collect function calls and populate the caller map.
319325
numberCallOpsContainedInFuncOp[funcOp] = 0;
320326
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
321-
func::FuncOp calledFunction = getCalledFunction(callOp);
327+
func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
322328
assert(calledFunction && "could not retrieved called func::FuncOp");
323329
// If the called function does not have any tensors in its signature, then
324330
// it is not necessary to bufferize the callee before the caller.

0 commit comments

Comments
 (0)