Skip to content

Commit 06bbbf1

Browse files
authored
[MLIR] Cyclic AttrType Replacer (#98206)
The current `AttrTypeReplacer` does not allow for custom handling of replacer functions that may cause self-recursion. For example, the replacement of one attr/type may depend on the replacement of another attr/type (by calling into the replacer manually again), which in turn may depend on the replacement of the original attr/type. To enable this functionality, this PR broke out the original AttrTypeReplacer into two parts: - An uncached base version (`detail::AttrTypeReplacerBase`) that allows registering replacer functions and has logic for invoking it on attr/types & their sub-elements - A cached version (`AttrTypeReplacer`) that provides the same caching as the original one. This is still the one used everywhere and behavior is unchanged. On top of the uncached base version, a `CyclicAttrTypeReplacer` is introduced that provides caching & cycle-handling for replacer logic that is cyclic. Cycle-breaking & caching is provided by the `CyclicReplacerCache` from #98202. Both concrete implementations of the uncached base version use CRTP to avoid dynamic dispatch. The base class merely provides replacer registration & invocation, and is not meant to be used, or otherwise extended elsewhere.
1 parent 026566a commit 06bbbf1

File tree

4 files changed

+467
-49
lines changed

4 files changed

+467
-49
lines changed

mlir/include/mlir/IR/AttrTypeSubElements.h

Lines changed: 119 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/IR/MLIRContext.h"
1818
#include "mlir/IR/Visitors.h"
19+
#include "mlir/Support/CyclicReplacerCache.h"
1920
#include "llvm/ADT/ArrayRef.h"
2021
#include "llvm/ADT/DenseMap.h"
2122
#include <optional>
@@ -116,9 +117,21 @@ class AttrTypeWalker {
116117
/// AttrTypeReplacer
117118
//===----------------------------------------------------------------------===//
118119

119-
/// This class provides a utility for replacing attributes/types, and their sub
120-
/// elements. Multiple replacement functions may be registered.
121-
class AttrTypeReplacer {
120+
namespace detail {
121+
122+
/// This class provides a base utility for replacing attributes/types, and their
123+
/// sub elements. Multiple replacement functions may be registered.
124+
///
125+
/// This base utility is uncached. Users can choose between two cached versions
126+
/// of this replacer:
127+
/// * For non-cyclic replacer logic, use `AttrTypeReplacer`.
128+
/// * For cyclic replacer logic, use `CyclicAttrTypeReplacer`.
129+
///
130+
/// Concrete implementations implement the following `replace` entry functions:
131+
/// * Attribute replace(Attribute attr);
132+
/// * Type replace(Type type);
133+
template <typename Concrete>
134+
class AttrTypeReplacerBase {
122135
public:
123136
//===--------------------------------------------------------------------===//
124137
// Application
@@ -139,12 +152,6 @@ class AttrTypeReplacer {
139152
bool replaceLocs = false,
140153
bool replaceTypes = false);
141154

142-
/// Replace the given attribute/type, and recursively replace any sub
143-
/// elements. Returns either the new attribute/type, or nullptr in the case of
144-
/// failure.
145-
Attribute replace(Attribute attr);
146-
Type replace(Type type);
147-
148155
//===--------------------------------------------------------------------===//
149156
// Registration
150157
//===--------------------------------------------------------------------===//
@@ -206,21 +213,114 @@ class AttrTypeReplacer {
206213
});
207214
}
208215

209-
private:
210-
/// Internal implementation of the `replace` methods above.
211-
template <typename T, typename ReplaceFns>
212-
T replaceImpl(T element, ReplaceFns &replaceFns);
213-
214-
/// Replace the sub elements of the given interface.
215-
template <typename T>
216-
T replaceSubElements(T interface);
216+
protected:
217+
/// Invokes the registered replacement functions from most recently registered
218+
/// to least recently registered until a successful replacement is returned.
219+
/// Unless skipping is requested, invokes `replace` on sub-elements of the
220+
/// current attr/type.
221+
Attribute replaceBase(Attribute attr);
222+
Type replaceBase(Type type);
217223

224+
private:
218225
/// The set of replacement functions that map sub elements.
219226
std::vector<ReplaceFn<Attribute>> attrReplacementFns;
220227
std::vector<ReplaceFn<Type>> typeReplacementFns;
228+
};
229+
230+
} // namespace detail
231+
232+
/// This is an attribute/type replacer that is naively cached. It is best used
233+
/// when the replacer logic is guaranteed to not contain cycles. Otherwise, any
234+
/// re-occurrence of an in-progress element will be skipped.
235+
class AttrTypeReplacer : public detail::AttrTypeReplacerBase<AttrTypeReplacer> {
236+
public:
237+
Attribute replace(Attribute attr);
238+
Type replace(Type type);
239+
240+
private:
241+
/// Shared concrete implementation of the public `replace` functions. Invokes
242+
/// `replaceBase` with caching.
243+
template <typename T>
244+
T cachedReplaceImpl(T element);
245+
246+
// Stores the opaque pointer of an attribute or type.
247+
DenseMap<const void *, const void *> cache;
248+
};
249+
250+
/// This is an attribute/type replacer that supports custom handling of cycles
251+
/// in the replacer logic. In addition to registering replacer functions, it
252+
/// allows registering cycle-breaking functions in the same style.
253+
class CyclicAttrTypeReplacer
254+
: public detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer> {
255+
public:
256+
CyclicAttrTypeReplacer();
221257

222-
/// The set of cached mappings for attributes/types.
223-
DenseMap<const void *, const void *> attrTypeMap;
258+
//===--------------------------------------------------------------------===//
259+
// Application
260+
//===--------------------------------------------------------------------===//
261+
262+
Attribute replace(Attribute attr);
263+
Type replace(Type type);
264+
265+
//===--------------------------------------------------------------------===//
266+
// Registration
267+
//===--------------------------------------------------------------------===//
268+
269+
/// A cycle-breaking function. This is invoked if the same element is asked to
270+
/// be replaced again when the first instance of it is still being replaced.
271+
/// This function must not perform any more recursive `replace` calls.
272+
/// If it is able to break the cycle, it should return a replacement result.
273+
/// Otherwise, it can return std::nullopt to defer cycle breaking to the next
274+
/// repeated element. However, the user must guarantee that, in any possible
275+
/// cycle, there always exists at least one element that can break the cycle.
276+
template <typename T>
277+
using CycleBreakerFn = std::function<std::optional<T>(T)>;
278+
279+
/// Register a cycle-breaking function.
280+
/// When breaking cycles, the mostly recently added cycle-breaking functions
281+
/// will be invoked first.
282+
void addCycleBreaker(CycleBreakerFn<Attribute> fn);
283+
void addCycleBreaker(CycleBreakerFn<Type> fn);
284+
285+
/// Register a cycle-breaking function that doesn't match the default
286+
/// signature.
287+
template <typename FnT,
288+
typename T = typename llvm::function_traits<
289+
std::decay_t<FnT>>::template arg_t<0>,
290+
typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
291+
Attribute, Type>>
292+
std::enable_if_t<!std::is_same_v<T, BaseT>> addCycleBreaker(FnT &&callback) {
293+
addCycleBreaker([callback = std::forward<FnT>(callback)](
294+
BaseT base) -> std::optional<BaseT> {
295+
if (auto derived = dyn_cast<T>(base))
296+
return callback(derived);
297+
return std::nullopt;
298+
});
299+
}
300+
301+
private:
302+
/// Invokes the registered cycle-breaker functions from most recently
303+
/// registered to least recently registered until a successful result is
304+
/// returned.
305+
std::optional<const void *> breakCycleImpl(void *element);
306+
307+
/// Shared concrete implementation of the public `replace` functions.
308+
template <typename T>
309+
T cachedReplaceImpl(T element);
310+
311+
/// The set of registered cycle-breaker functions.
312+
std::vector<CycleBreakerFn<Attribute>> attrCycleBreakerFns;
313+
std::vector<CycleBreakerFn<Type>> typeCycleBreakerFns;
314+
315+
/// A cache of previously-replaced attr/types.
316+
/// The key of the cache is the opaque value of an AttrOrType. Using
317+
/// AttrOrType allows distinguishing between the two types when invoking
318+
/// cycle-breakers. Using its opaque value avoids the cyclic dependency issue
319+
/// of directly using `AttrOrType` to instantiate the cache.
320+
/// The value of the cache is just the opaque value of the attr/type itself
321+
/// (not the PointerUnion).
322+
using AttrOrType = PointerUnion<Attribute, Type>;
323+
CyclicReplacerCache<void *, const void *> cache;
224324
};
225325

226326
//===----------------------------------------------------------------------===//

mlir/lib/IR/AttrTypeSubElements.cpp

Lines changed: 116 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -67,22 +67,28 @@ WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
6767
}
6868

6969
//===----------------------------------------------------------------------===//
70-
/// AttrTypeReplacer
70+
/// AttrTypeReplacerBase
7171
//===----------------------------------------------------------------------===//
7272

73-
void AttrTypeReplacer::addReplacement(ReplaceFn<Attribute> fn) {
73+
template <typename Concrete>
74+
void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
75+
ReplaceFn<Attribute> fn) {
7476
attrReplacementFns.emplace_back(std::move(fn));
7577
}
76-
void AttrTypeReplacer::addReplacement(ReplaceFn<Type> fn) {
78+
79+
template <typename Concrete>
80+
void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
81+
ReplaceFn<Type> fn) {
7782
typeReplacementFns.push_back(std::move(fn));
7883
}
7984

80-
void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
81-
bool replaceLocs, bool replaceTypes) {
85+
template <typename Concrete>
86+
void detail::AttrTypeReplacerBase<Concrete>::replaceElementsIn(
87+
Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
8288
// Functor that replaces the given element if the new value is different,
8389
// otherwise returns nullptr.
8490
auto replaceIfDifferent = [&](auto element) {
85-
auto replacement = replace(element);
91+
auto replacement = static_cast<Concrete *>(this)->replace(element);
8692
return (replacement && replacement != element) ? replacement : nullptr;
8793
};
8894

@@ -127,17 +133,16 @@ void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
127133
}
128134
}
129135

130-
void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op,
131-
bool replaceAttrs,
132-
bool replaceLocs,
133-
bool replaceTypes) {
136+
template <typename Concrete>
137+
void detail::AttrTypeReplacerBase<Concrete>::recursivelyReplaceElementsIn(
138+
Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
134139
op->walk([&](Operation *nestedOp) {
135140
replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes);
136141
});
137142
}
138143

139-
template <typename T>
140-
static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
144+
template <typename T, typename Replacer>
145+
static void updateSubElementImpl(T element, Replacer &replacer,
141146
SmallVectorImpl<T> &newElements,
142147
FailureOr<bool> &changed) {
143148
// Bail early if we failed at any point.
@@ -160,18 +165,18 @@ static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
160165
}
161166
}
162167

163-
template <typename T>
164-
T AttrTypeReplacer::replaceSubElements(T interface) {
168+
template <typename T, typename Replacer>
169+
static T replaceSubElements(T interface, Replacer &replacer) {
165170
// Walk the current sub-elements, replacing them as necessary.
166171
SmallVector<Attribute, 16> newAttrs;
167172
SmallVector<Type, 16> newTypes;
168173
FailureOr<bool> changed = false;
169174
interface.walkImmediateSubElements(
170175
[&](Attribute element) {
171-
updateSubElementImpl(element, *this, newAttrs, changed);
176+
updateSubElementImpl(element, replacer, newAttrs, changed);
172177
},
173178
[&](Type element) {
174-
updateSubElementImpl(element, *this, newTypes, changed);
179+
updateSubElementImpl(element, replacer, newTypes, changed);
175180
});
176181
if (failed(changed))
177182
return nullptr;
@@ -184,13 +189,9 @@ T AttrTypeReplacer::replaceSubElements(T interface) {
184189
}
185190

186191
/// Shared implementation of replacing a given attribute or type element.
187-
template <typename T, typename ReplaceFns>
188-
T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
189-
const void *opaqueElement = element.getAsOpaquePointer();
190-
auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement);
191-
if (!inserted)
192-
return T::getFromOpaquePointer(it->second);
193-
192+
template <typename T, typename ReplaceFns, typename Replacer>
193+
static T replaceElementImpl(T element, ReplaceFns &replaceFns,
194+
Replacer &replacer) {
194195
T result = element;
195196
WalkResult walkResult = WalkResult::advance();
196197
for (auto &replaceFn : llvm::reverse(replaceFns)) {
@@ -202,29 +203,114 @@ T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
202203

203204
// If an error occurred, return nullptr to indicate failure.
204205
if (walkResult.wasInterrupted() || !result) {
205-
attrTypeMap[opaqueElement] = nullptr;
206206
return nullptr;
207207
}
208208

209209
// Handle replacing sub-elements if this element is also a container.
210210
if (!walkResult.wasSkipped()) {
211211
// Replace the sub elements of this element, bailing if we fail.
212-
if (!(result = replaceSubElements(result))) {
213-
attrTypeMap[opaqueElement] = nullptr;
212+
if (!(result = replaceSubElements(result, replacer))) {
214213
return nullptr;
215214
}
216215
}
217216

218-
attrTypeMap[opaqueElement] = result.getAsOpaquePointer();
217+
return result;
218+
}
219+
220+
template <typename Concrete>
221+
Attribute detail::AttrTypeReplacerBase<Concrete>::replaceBase(Attribute attr) {
222+
return replaceElementImpl(attr, attrReplacementFns,
223+
*static_cast<Concrete *>(this));
224+
}
225+
226+
template <typename Concrete>
227+
Type detail::AttrTypeReplacerBase<Concrete>::replaceBase(Type type) {
228+
return replaceElementImpl(type, typeReplacementFns,
229+
*static_cast<Concrete *>(this));
230+
}
231+
232+
//===----------------------------------------------------------------------===//
233+
/// AttrTypeReplacer
234+
//===----------------------------------------------------------------------===//
235+
236+
template class detail::AttrTypeReplacerBase<AttrTypeReplacer>;
237+
238+
template <typename T>
239+
T AttrTypeReplacer::cachedReplaceImpl(T element) {
240+
const void *opaqueElement = element.getAsOpaquePointer();
241+
auto [it, inserted] = cache.try_emplace(opaqueElement, opaqueElement);
242+
if (!inserted)
243+
return T::getFromOpaquePointer(it->second);
244+
245+
T result = replaceBase(element);
246+
247+
cache[opaqueElement] = result.getAsOpaquePointer();
219248
return result;
220249
}
221250

222251
Attribute AttrTypeReplacer::replace(Attribute attr) {
223-
return replaceImpl(attr, attrReplacementFns);
252+
return cachedReplaceImpl(attr);
224253
}
225254

226-
Type AttrTypeReplacer::replace(Type type) {
227-
return replaceImpl(type, typeReplacementFns);
255+
Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(type); }
256+
257+
//===----------------------------------------------------------------------===//
258+
/// CyclicAttrTypeReplacer
259+
//===----------------------------------------------------------------------===//
260+
261+
template class detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer>;
262+
263+
CyclicAttrTypeReplacer::CyclicAttrTypeReplacer()
264+
: cache([&](void *attr) { return breakCycleImpl(attr); }) {}
265+
266+
void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Attribute> fn) {
267+
attrCycleBreakerFns.emplace_back(std::move(fn));
268+
}
269+
270+
void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Type> fn) {
271+
typeCycleBreakerFns.emplace_back(std::move(fn));
272+
}
273+
274+
template <typename T>
275+
T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) {
276+
void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue();
277+
CyclicReplacerCache<void *, const void *>::CacheEntry cacheEntry =
278+
cache.lookupOrInit(opaqueTaggedElement);
279+
if (auto resultOpt = cacheEntry.get())
280+
return T::getFromOpaquePointer(*resultOpt);
281+
282+
T result = replaceBase(element);
283+
284+
cacheEntry.resolve(result.getAsOpaquePointer());
285+
return result;
286+
}
287+
288+
Attribute CyclicAttrTypeReplacer::replace(Attribute attr) {
289+
return cachedReplaceImpl(attr);
290+
}
291+
292+
Type CyclicAttrTypeReplacer::replace(Type type) {
293+
return cachedReplaceImpl(type);
294+
}
295+
296+
std::optional<const void *>
297+
CyclicAttrTypeReplacer::breakCycleImpl(void *element) {
298+
AttrOrType attrType = AttrOrType::getFromOpaqueValue(element);
299+
if (auto attr = dyn_cast<Attribute>(attrType)) {
300+
for (auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) {
301+
if (std::optional<Attribute> newRes = cyclicReplaceFn(attr)) {
302+
return newRes->getAsOpaquePointer();
303+
}
304+
}
305+
} else {
306+
auto type = dyn_cast<Type>(attrType);
307+
for (auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) {
308+
if (std::optional<Type> newRes = cyclicReplaceFn(type)) {
309+
return newRes->getAsOpaquePointer();
310+
}
311+
}
312+
}
313+
return std::nullopt;
228314
}
229315

230316
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)