Skip to content

Commit cd91bc8

Browse files
committed
Fix bounds of encoded type variables in quote patterns
When we encode quote patterns into `unapply` methods, we need to create a copy of each type variable. One copy is kept within the quote(in the next stage) and the other is used in the `unapply` method to define the usual pattern type variables. When creating the latter we copied the symbols but did not update the infos. This implies that if type variables would be bounded by each other, the bounds of the copies would be the original types instead of the copies. We need to update those references. To update the info we now create all the symbols in one pass and the update all their infos in a second pass. This also implies that we cannot use the `newPatternBoundSymbol` to create the symbol as this constructor will register the info into GADT bounds. Instead we use the plain `newSymbol`. Then in the second pass, when we have updated the infos, we register the symbol into GADT bounds. Note that the code in the added test does compiles correctly, but it had the inconsistent bounds. This test is added in case we need to manually inspect the bounds latter. This test does fail to compile in #17935 if this fix is not applied.
1 parent 229dc12 commit cd91bc8

File tree

2 files changed

+59
-9
lines changed

2 files changed

+59
-9
lines changed

compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,7 @@ trait QuotesAndSplices {
202202
private def splitQuotePattern(quoted: Tree)(using Context): (collection.Map[Symbol, Bind], Tree, List[Tree]) = {
203203
val ctx0 = ctx
204204

205-
val typeBindings: mutable.Map[Symbol, Bind] = mutable.LinkedHashMap.empty
206-
def getBinding(sym: Symbol): Bind =
207-
typeBindings.getOrElseUpdate(sym, {
208-
val bindingBounds = sym.info
209-
val bsym = newPatternBoundSymbol(sym.name.toString.stripPrefix("$").toTypeName, bindingBounds, quoted.span)
210-
Bind(bsym, untpd.Ident(nme.WILDCARD).withType(bindingBounds)).withSpan(quoted.span)
211-
})
205+
val bindSymMapping: collection.Map[Symbol, Bind] = unapplyBindingsMapping(quoted)
212206

213207
object splitter extends tpd.TreeMap {
214208
private var variance: Int = 1
@@ -288,7 +282,7 @@ trait QuotesAndSplices {
288282
report.error(IllegalVariableInPatternAlternative(tdef.symbol.name), tdef.srcPos)
289283
if variance == -1 then
290284
tdef.symbol.addAnnotation(Annotation(New(ref(defn.QuotedRuntimePatterns_fromAboveAnnot.typeRef)).withSpan(tdef.span)))
291-
val bindingType = getBinding(tdef.symbol).symbol.typeRef
285+
val bindingType = bindSymMapping(tdef.symbol).symbol.typeRef
292286
val bindingTypeTpe = AppliedType(defn.QuotedTypeClass.typeRef, bindingType :: Nil)
293287
val sym = newPatternBoundSymbol(nameOfSyntheticGiven, bindingTypeTpe, tdef.span, flags = ImplicitVal)(using ctx0)
294288
buff += Bind(sym, untpd.Ident(nme.WILDCARD).withType(bindingTypeTpe)).withSpan(tdef.span)
@@ -325,7 +319,51 @@ trait QuotesAndSplices {
325319
new TreeTypeMap(typeMap = typeMap).transform(shape1)
326320
}
327321

328-
(typeBindings, shape2, patterns)
322+
(bindSymMapping, shape2, patterns)
323+
}
324+
325+
/** For each type variable defined in the quote pattern we generate an equivalent
326+
* binding that will be as type variable in the encoded `unapply` of the quote pattern.
327+
*
328+
* @return Mapping from type variable symbols defined in the quote pattern into
329+
* type variable `Bind` definitions for the `unapply` of the quote pattern.
330+
* This mapping retains the original type variable definition order.
331+
*/
332+
private def unapplyBindingsMapping(quoted: Tree)(using Context): collection.Map[Symbol, Bind] = {
333+
val mapping = mutable.LinkedHashMap.empty[Symbol, Symbol]
334+
// Collect all existing type variable bindings and create new symbols for them.
335+
// The old info is used, it may contain references to the old symbols.
336+
new tpd.TreeTraverser {
337+
def traverse(tree: Tree)(using Context): Unit = tree match {
338+
case _: SplicePattern =>
339+
case Select(pat: Bind, _) if tree.symbol.isTypeSplice =>
340+
val sym = tree.tpe.dealias.typeSymbol
341+
if sym.exists then registerNewBindSym(sym)
342+
case tdef: TypeDef =>
343+
if tdef.symbol.hasAnnotation(defn.QuotedRuntimePatterns_patternTypeAnnot) then
344+
registerNewBindSym(tdef.symbol)
345+
traverseChildren(tdef)
346+
case _ =>
347+
traverseChildren(tree)
348+
}
349+
private def registerNewBindSym(sym: Symbol): Unit =
350+
if !mapping.contains(sym) then
351+
mapping(sym) = newSymbol(ctx.owner, sym.name.toString.stripPrefix("$").toTypeName, Case | sym.flags, sym.info, coord = quoted.span)
352+
}.traverse(quoted)
353+
354+
// Replace symbols in `mapping` in the infos of the new symbol and register GADT bounds.
355+
// GADT bounds need to be added after the info is updated to avoid references to the old symbols.
356+
var oldBindings: List[Symbol] = mapping.keys.toList
357+
var newBindingsRefs: List[Type] = mapping.values.toList.map(_.typeRef)
358+
for newBindings <- mapping.values do
359+
newBindings.info = newBindings.info.subst(oldBindings, newBindingsRefs)
360+
ctx.gadtState.addToConstraint(newBindings) // This must be preformed after the info has been updated
361+
362+
// Map into Bind nodes retaining the original order
363+
val mapping2: mutable.Map[Symbol, Bind] = mutable.LinkedHashMap.empty
364+
for (oldSym, newSym) <- mapping do
365+
mapping2(oldSym) = Bind(newSym, untpd.Ident(nme.WILDCARD).withType(newSym.info)).withSpan(quoted.span)
366+
mapping2
329367
}
330368

331369
/** Type a quote pattern `case '{ <quoted> } =>` qiven the a current prototype. Typing the pattern
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import quoted.*
2+
3+
def foo(using Quotes)(x: Expr[Int]) =
4+
x match
5+
case '{ type t; type u <: `t`; f[`t`, `u`] } =>
6+
case '{ type u <: `t`; type t; f[`t`, `u`] } =>
7+
case '{ type t; type u <: `t`; g[F[`t`, `u`]] } =>
8+
case '{ type u <: `t`; type t; g[F[`t`, `u`]] } =>
9+
10+
def f[T, U <: T] = ???
11+
def g[T] = ???
12+
type F[T, U <: T]

0 commit comments

Comments
 (0)