@@ -67,22 +67,28 @@ WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
67
67
}
68
68
69
69
// ===----------------------------------------------------------------------===//
70
- // / AttrTypeReplacer
70
+ // / AttrTypeReplacerBase
71
71
// ===----------------------------------------------------------------------===//
72
72
73
- void AttrTypeReplacer::addReplacement (ReplaceFn<Attribute> fn) {
73
+ template <typename Concrete>
74
+ void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
75
+ ReplaceFn<Attribute> fn) {
74
76
attrReplacementFns.emplace_back (std::move (fn));
75
77
}
76
- void AttrTypeReplacer::addReplacement (ReplaceFn<Type> fn) {
78
+
79
+ template <typename Concrete>
80
+ void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
81
+ ReplaceFn<Type> fn) {
77
82
typeReplacementFns.push_back (std::move (fn));
78
83
}
79
84
80
- void AttrTypeReplacer::replaceElementsIn (Operation *op, bool replaceAttrs,
81
- bool replaceLocs, bool replaceTypes) {
85
+ template <typename Concrete>
86
+ void detail::AttrTypeReplacerBase<Concrete>::replaceElementsIn(
87
+ Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
82
88
// Functor that replaces the given element if the new value is different,
83
89
// otherwise returns nullptr.
84
90
auto replaceIfDifferent = [&](auto element) {
85
- auto replacement = replace (element);
91
+ auto replacement = static_cast <Concrete *>( this )-> replace (element);
86
92
return (replacement && replacement != element) ? replacement : nullptr ;
87
93
};
88
94
@@ -127,17 +133,16 @@ void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
127
133
}
128
134
}
129
135
130
- void AttrTypeReplacer::recursivelyReplaceElementsIn (Operation *op,
131
- bool replaceAttrs,
132
- bool replaceLocs,
133
- bool replaceTypes) {
136
+ template <typename Concrete>
137
+ void detail::AttrTypeReplacerBase<Concrete>::recursivelyReplaceElementsIn(
138
+ Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
134
139
op->walk ([&](Operation *nestedOp) {
135
140
replaceElementsIn (nestedOp, replaceAttrs, replaceLocs, replaceTypes);
136
141
});
137
142
}
138
143
139
- template <typename T>
140
- static void updateSubElementImpl (T element, AttrTypeReplacer &replacer,
144
+ template <typename T, typename Replacer >
145
+ static void updateSubElementImpl (T element, Replacer &replacer,
141
146
SmallVectorImpl<T> &newElements,
142
147
FailureOr<bool > &changed) {
143
148
// Bail early if we failed at any point.
@@ -160,18 +165,18 @@ static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
160
165
}
161
166
}
162
167
163
- template <typename T>
164
- T AttrTypeReplacer:: replaceSubElements (T interface) {
168
+ template <typename T, typename Replacer >
169
+ static T replaceSubElements (T interface, Replacer &replacer ) {
165
170
// Walk the current sub-elements, replacing them as necessary.
166
171
SmallVector<Attribute, 16 > newAttrs;
167
172
SmallVector<Type, 16 > newTypes;
168
173
FailureOr<bool > changed = false ;
169
174
interface.walkImmediateSubElements (
170
175
[&](Attribute element) {
171
- updateSubElementImpl (element, * this , newAttrs, changed);
176
+ updateSubElementImpl (element, replacer , newAttrs, changed);
172
177
},
173
178
[&](Type element) {
174
- updateSubElementImpl (element, * this , newTypes, changed);
179
+ updateSubElementImpl (element, replacer , newTypes, changed);
175
180
});
176
181
if (failed (changed))
177
182
return nullptr ;
@@ -184,13 +189,9 @@ T AttrTypeReplacer::replaceSubElements(T interface) {
184
189
}
185
190
186
191
// / Shared implementation of replacing a given attribute or type element.
187
- template <typename T, typename ReplaceFns>
188
- T AttrTypeReplacer::replaceImpl (T element, ReplaceFns &replaceFns) {
189
- const void *opaqueElement = element.getAsOpaquePointer ();
190
- auto [it, inserted] = attrTypeMap.try_emplace (opaqueElement, opaqueElement);
191
- if (!inserted)
192
- return T::getFromOpaquePointer (it->second );
193
-
192
+ template <typename T, typename ReplaceFns, typename Replacer>
193
+ static T replaceElementImpl (T element, ReplaceFns &replaceFns,
194
+ Replacer &replacer) {
194
195
T result = element;
195
196
WalkResult walkResult = WalkResult::advance ();
196
197
for (auto &replaceFn : llvm::reverse (replaceFns)) {
@@ -202,29 +203,114 @@ T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
202
203
203
204
// If an error occurred, return nullptr to indicate failure.
204
205
if (walkResult.wasInterrupted () || !result) {
205
- attrTypeMap[opaqueElement] = nullptr ;
206
206
return nullptr ;
207
207
}
208
208
209
209
// Handle replacing sub-elements if this element is also a container.
210
210
if (!walkResult.wasSkipped ()) {
211
211
// Replace the sub elements of this element, bailing if we fail.
212
- if (!(result = replaceSubElements (result))) {
213
- attrTypeMap[opaqueElement] = nullptr ;
212
+ if (!(result = replaceSubElements (result, replacer))) {
214
213
return nullptr ;
215
214
}
216
215
}
217
216
218
- attrTypeMap[opaqueElement] = result.getAsOpaquePointer ();
217
+ return result;
218
+ }
219
+
220
+ template <typename Concrete>
221
+ Attribute detail::AttrTypeReplacerBase<Concrete>::replaceBase(Attribute attr) {
222
+ return replaceElementImpl (attr, attrReplacementFns,
223
+ *static_cast <Concrete *>(this ));
224
+ }
225
+
226
+ template <typename Concrete>
227
+ Type detail::AttrTypeReplacerBase<Concrete>::replaceBase(Type type) {
228
+ return replaceElementImpl (type, typeReplacementFns,
229
+ *static_cast <Concrete *>(this ));
230
+ }
231
+
232
+ // ===----------------------------------------------------------------------===//
233
+ // / AttrTypeReplacer
234
+ // ===----------------------------------------------------------------------===//
235
+
236
+ template class detail ::AttrTypeReplacerBase<AttrTypeReplacer>;
237
+
238
+ template <typename T>
239
+ T AttrTypeReplacer::cachedReplaceImpl (T element) {
240
+ const void *opaqueElement = element.getAsOpaquePointer ();
241
+ auto [it, inserted] = cache.try_emplace (opaqueElement, opaqueElement);
242
+ if (!inserted)
243
+ return T::getFromOpaquePointer (it->second );
244
+
245
+ T result = replaceBase (element);
246
+
247
+ cache[opaqueElement] = result.getAsOpaquePointer ();
219
248
return result;
220
249
}
221
250
222
251
Attribute AttrTypeReplacer::replace (Attribute attr) {
223
- return replaceImpl (attr, attrReplacementFns );
252
+ return cachedReplaceImpl (attr);
224
253
}
225
254
226
- Type AttrTypeReplacer::replace (Type type) {
227
- return replaceImpl (type, typeReplacementFns);
255
+ Type AttrTypeReplacer::replace (Type type) { return cachedReplaceImpl (type); }
256
+
257
+ // ===----------------------------------------------------------------------===//
258
+ // / CyclicAttrTypeReplacer
259
+ // ===----------------------------------------------------------------------===//
260
+
261
+ template class detail ::AttrTypeReplacerBase<CyclicAttrTypeReplacer>;
262
+
263
+ CyclicAttrTypeReplacer::CyclicAttrTypeReplacer ()
264
+ : cache([&](void *attr) { return breakCycleImpl (attr); }) {}
265
+
266
+ void CyclicAttrTypeReplacer::addCycleBreaker (CycleBreakerFn<Attribute> fn) {
267
+ attrCycleBreakerFns.emplace_back (std::move (fn));
268
+ }
269
+
270
+ void CyclicAttrTypeReplacer::addCycleBreaker (CycleBreakerFn<Type> fn) {
271
+ typeCycleBreakerFns.emplace_back (std::move (fn));
272
+ }
273
+
274
+ template <typename T>
275
+ T CyclicAttrTypeReplacer::cachedReplaceImpl (T element) {
276
+ void *opaqueTaggedElement = AttrOrType (element).getOpaqueValue ();
277
+ CyclicReplacerCache<void *, const void *>::CacheEntry cacheEntry =
278
+ cache.lookupOrInit (opaqueTaggedElement);
279
+ if (auto resultOpt = cacheEntry.get ())
280
+ return T::getFromOpaquePointer (*resultOpt);
281
+
282
+ T result = replaceBase (element);
283
+
284
+ cacheEntry.resolve (result.getAsOpaquePointer ());
285
+ return result;
286
+ }
287
+
288
+ Attribute CyclicAttrTypeReplacer::replace (Attribute attr) {
289
+ return cachedReplaceImpl (attr);
290
+ }
291
+
292
+ Type CyclicAttrTypeReplacer::replace (Type type) {
293
+ return cachedReplaceImpl (type);
294
+ }
295
+
296
+ std::optional<const void *>
297
+ CyclicAttrTypeReplacer::breakCycleImpl (void *element) {
298
+ AttrOrType attrType = AttrOrType::getFromOpaqueValue (element);
299
+ if (auto attr = dyn_cast<Attribute>(attrType)) {
300
+ for (auto &cyclicReplaceFn : llvm::reverse (attrCycleBreakerFns)) {
301
+ if (std::optional<Attribute> newRes = cyclicReplaceFn (attr)) {
302
+ return newRes->getAsOpaquePointer ();
303
+ }
304
+ }
305
+ } else {
306
+ auto type = dyn_cast<Type>(attrType);
307
+ for (auto &cyclicReplaceFn : llvm::reverse (typeCycleBreakerFns)) {
308
+ if (std::optional<Type> newRes = cyclicReplaceFn (type)) {
309
+ return newRes->getAsOpaquePointer ();
310
+ }
311
+ }
312
+ }
313
+ return std::nullopt;
228
314
}
229
315
230
316
// ===----------------------------------------------------------------------===//
0 commit comments