Skip to content

Commit 7fd358b

Browse files
committed
Fix bounds of encoded type variables in quote patterns
When we encode quote patterns into `unapply` methods, we need to create a copy of each type variable. One copy is kept within the quote(in the next stage) and the other is used in the `unapply` method to define the usual pattern type variables. When creating the latter we copied the symbols but did not update the infos. This implies that if type variables would be bounded by each other, the bounds of the copies would be the original types instead of the copies. We need to update those references. To update the info we now create all the symbols in one pass and the update all their infos in a second pass. This also implies that we cannot use the `newPatternBoundSymbol` to create the symbol as this constructor will register the info into GADT bounds. Instead we use the plain `newSymbol`. Then in the second pass, when we have updated the infos, we register the symbol into GADT bounds. Note that the code in the added test does compiles correctly, but it had the inconsistent bounds. This test is added in case we need to manually inspect the bounds latter. This test does fail to compile in #17935 if this fix is not applied.
1 parent c528979 commit 7fd358b

File tree

2 files changed

+59
-9
lines changed

2 files changed

+59
-9
lines changed

compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,7 @@ trait QuotesAndSplices {
214214
private def splitQuotePattern(quoted: Tree)(using Context): (collection.Map[Symbol, Bind], Tree, List[Tree]) = {
215215
val ctx0 = ctx
216216

217-
val typeBindings: mutable.Map[Symbol, Bind] = mutable.LinkedHashMap.empty
218-
def getBinding(sym: Symbol): Bind =
219-
typeBindings.getOrElseUpdate(sym, {
220-
val bindingBounds = sym.info
221-
val bsym = newPatternBoundSymbol(sym.name.toString.stripPrefix("$").toTypeName, bindingBounds, quoted.span)
222-
Bind(bsym, untpd.Ident(nme.WILDCARD).withType(bindingBounds)).withSpan(quoted.span)
223-
})
217+
val bindSymMapping: collection.Map[Symbol, Bind] = unapplyBindingsMapping(quoted)
224218

225219
object splitter extends tpd.TreeMap {
226220
private var variance: Int = 1
@@ -300,7 +294,7 @@ trait QuotesAndSplices {
300294
report.error(IllegalVariableInPatternAlternative(tdef.symbol.name), tdef.srcPos)
301295
if variance == -1 then
302296
tdef.symbol.addAnnotation(Annotation(New(ref(defn.QuotedRuntimePatterns_fromAboveAnnot.typeRef)).withSpan(tdef.span)))
303-
val bindingType = getBinding(tdef.symbol).symbol.typeRef
297+
val bindingType = bindSymMapping(tdef.symbol).symbol.typeRef
304298
val bindingTypeTpe = AppliedType(defn.QuotedTypeClass.typeRef, bindingType :: Nil)
305299
val sym = newPatternBoundSymbol(nameOfSyntheticGiven, bindingTypeTpe, tdef.span, flags = ImplicitVal)(using ctx0)
306300
buff += Bind(sym, untpd.Ident(nme.WILDCARD).withType(bindingTypeTpe)).withSpan(tdef.span)
@@ -337,7 +331,51 @@ trait QuotesAndSplices {
337331
new TreeTypeMap(typeMap = typeMap).transform(shape1)
338332
}
339333

340-
(typeBindings, shape2, patterns)
334+
(bindSymMapping, shape2, patterns)
335+
}
336+
337+
/** For each type variable defined in the quote pattern we generate an equivalent
338+
* binding that will be as type variable in the encoded `unapply` of the quote pattern.
339+
*
340+
* @return Mapping from type variable symbols defined in the quote pattern into
341+
* type variable `Bind` definitions for the `unapply` of the quote pattern.
342+
* This mapping retains the original type variable definition order.
343+
*/
344+
private def unapplyBindingsMapping(quoted: Tree)(using Context): collection.Map[Symbol, Bind] = {
345+
val mapping = mutable.LinkedHashMap.empty[Symbol, Symbol]
346+
// Collect all existing type variable bindings and create new symbols for them.
347+
// The old info is used, it may contain references to the old symbols.
348+
new tpd.TreeTraverser {
349+
def traverse(tree: Tree)(using Context): Unit = tree match {
350+
case _: SplicePattern =>
351+
case Select(pat: Bind, _) if tree.symbol.isTypeSplice =>
352+
val sym = tree.tpe.dealias.typeSymbol
353+
if sym.exists then registerNewBindSym(sym)
354+
case tdef: TypeDef =>
355+
if tdef.symbol.hasAnnotation(defn.QuotedRuntimePatterns_patternTypeAnnot) then
356+
registerNewBindSym(tdef.symbol)
357+
traverseChildren(tdef)
358+
case _ =>
359+
traverseChildren(tree)
360+
}
361+
private def registerNewBindSym(sym: Symbol): Unit =
362+
if !mapping.contains(sym) then
363+
mapping(sym) = newSymbol(ctx.owner, sym.name.toString.stripPrefix("$").toTypeName, Case | sym.flags, sym.info, coord = quoted.span)
364+
}.traverse(quoted)
365+
366+
// Replace symbols in `mapping` in the infos of the new symbol and register GADT bounds.
367+
// GADT bounds need to be added after the info is updated to avoid references to the old symbols.
368+
var oldBindings: List[Symbol] = mapping.keys.toList
369+
var newBindingsRefs: List[Type] = mapping.values.toList.map(_.typeRef)
370+
for newBindings <- mapping.values do
371+
newBindings.info = newBindings.info.subst(oldBindings, newBindingsRefs)
372+
ctx.gadtState.addToConstraint(newBindings) // This must be preformed after the info has been updated
373+
374+
// Map into Bind nodes retaining the original order
375+
val mapping2: mutable.Map[Symbol, Bind] = mutable.LinkedHashMap.empty
376+
for (oldSym, newSym) <- mapping do
377+
mapping2(oldSym) = Bind(newSym, untpd.Ident(nme.WILDCARD).withType(newSym.info)).withSpan(quoted.span)
378+
mapping2
341379
}
342380

343381
/** Type a quote pattern `case '{ <quoted> } =>` qiven the a current prototype. Typing the pattern
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import quoted.*
2+
3+
def foo(using Quotes)(x: Expr[Int]) =
4+
x match
5+
case '{ type t; type u <: `t`; f[`t`, `u`] } =>
6+
case '{ type u <: `t`; type t; f[`t`, `u`] } =>
7+
case '{ type t; type u <: `t`; g[F[`t`, `u`]] } =>
8+
case '{ type u <: `t`; type t; g[F[`t`, `u`]] } =>
9+
10+
def f[T, U <: T] = ???
11+
def g[T] = ???
12+
type F[T, U <: T]

0 commit comments

Comments
 (0)