Skip to content

Commit ec65eb3

Browse files
committed
Fix captured references to singleton types
When we had a reference to a `x.type` we mistakenly captured `x` instead of `x.type`. This was caused because `SingletonTypeTree` was not handled in `Splicing`. Fixing this uncovered some inconsistencies with the types in the encoding of the hole captured types and contents. These have been fixed as well.
1 parent bc90b96 commit ec65eb3

File tree

5 files changed

+48
-5
lines changed

5 files changed

+48
-5
lines changed

compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,20 @@ class PickleQuotes extends MacroTransform {
157157
override def apply(tp: Type): Type = tp match
158158
case tp: TypeRef if tp.typeSymbol.isTypeSplice =>
159159
apply(tp.dealias)
160-
case tp @ TypeRef(pre, _) if pre == NoPrefix || pre.termSymbol.isLocal =>
160+
case tp @ TypeRef(pre, _) if isLocalPath(pre) =>
161161
val hiBound = tp.typeSymbol.info match
162162
case info: ClassInfo => info.parents.reduce(_ & _)
163163
case info => info.hiBound
164164
apply(hiBound)
165+
case tp @ TermRef(pre, _) if isLocalPath(pre) =>
166+
apply(tp.widenTermRefExpr)
165167
case tp =>
166168
mapOver(tp)
169+
170+
private def isLocalPath(tp: Type): Boolean = tp match
171+
case NoPrefix => true
172+
case tp: TermRef if !tp.symbol.is(Package) => isLocalPath(tp.prefix)
173+
case tp => false
167174
}
168175

169176
/** Remove references to local types that will not be defined in this quote */

compiler/src/dotty/tools/dotc/transform/Splicing.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ class Splicing extends MacroTransform:
232232
// Dealias references to captured types
233233
TypeTree(tree.tpe.dealias)
234234
else super.transform(tree)
235-
case tree: TypeTree =>
235+
case _: TypeTree | _: SingletonTypeTree =>
236236
if containsCapturedType(tree.tpe) && level >= 1 then getTagRefFor(tree)
237237
else tree
238238
case tree @ Assign(lhs: RefTree, rhs) =>
@@ -361,9 +361,8 @@ class Splicing extends MacroTransform:
361361
)
362362

363363
private def capturedType(tree: Tree)(using Context): Symbol =
364-
val tpe = tree.tpe.widenTermRefExpr
365364
val bindingSym = refBindingMap
366-
.getOrElseUpdate(tree.symbol, (TypeTree(tree.tpe), newQuotedTypeClassBinding(tpe)))._2
365+
.getOrElseUpdate(tree.symbol, (TypeTree(tree.tpe), newQuotedTypeClassBinding(tree.tpe)))._2
367366
bindingSym
368367

369368
private def capturedPartTypes(tpt: Tree)(using Context): Tree =

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ object TreeChecker {
658658
// Check that we only add the captured type `T` instead of a more complex type like `List[T]`.
659659
// If we have `F[T]` with captured `F` and `T`, we should list `F` and `T` separately in the args.
660660
for arg <- args do
661-
assert(arg.isTerm || arg.tpe.isInstanceOf[TypeRef], "Expected TypeRef in Hole type args but got: " + arg.tpe)
661+
assert(arg.isTerm || arg.tpe.isInstanceOf[TypeRef] || arg.tpe.isInstanceOf[TermRef], "Expected TypeRef or TermRef in Hole type args but got: " + arg.tpe)
662662

663663
// Check result type of the hole
664664
if isTermHole then assert(tpt.typeOpt <:< pt)

tests/neg-macros/i17103.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.quoted.*
2+
3+
def test(using Quotes): Expr[Unit] =
4+
'{
5+
trait C:
6+
def d: Int
7+
val c: C = ???
8+
${
9+
val expr = '{
10+
val cRef: c.type = ???
11+
cRef.d // error
12+
()
13+
}
14+
expr
15+
}
16+
}

tests/pos-macros/i17103b.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import scala.quoted.*
2+
3+
trait C0:
4+
def d: Int
5+
6+
def test(using Quotes): Expr[Unit] =
7+
'{
8+
trait C1 extends C0:
9+
def d: Int
10+
trait C extends C1:
11+
def d: Int
12+
val c: C = ???
13+
${
14+
val expr = '{
15+
val cRef: c.type = ???
16+
cRef.d // calls C0.d
17+
()
18+
}
19+
expr
20+
}
21+
}

0 commit comments

Comments
 (0)