Skip to content

Commit 6a3302e

Browse files
committed
Fix captured references to singleton types
When we had a reference to a `x.type` we mistakenly captured `x` instead of `x.type`. This was caused because `SingletonTypeTree` was not handled in `Splicing`. Fixing this uncovered some inconsistencies with the types in the encoding of the hole captured types and contents. These have been fixed as well.
1 parent d36cd2d commit 6a3302e

File tree

5 files changed

+38
-5
lines changed

5 files changed

+38
-5
lines changed

compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,20 @@ class PickleQuotes extends MacroTransform {
157157
override def apply(tp: Type): Type = tp match
158158
case tp: TypeRef if tp.typeSymbol.isTypeSplice =>
159159
apply(tp.dealias)
160-
case tp @ TypeRef(pre, _) if pre == NoPrefix || pre.termSymbol.isLocal =>
160+
case tp @ TypeRef(pre, _) if isLocalPath(pre) =>
161161
val hiBound = tp.typeSymbol.info match
162162
case info: ClassInfo => info.parents.reduce(_ & _)
163163
case info => info.hiBound
164164
apply(hiBound)
165+
case tp @ TermRef(pre, _) if isLocalPath(pre) =>
166+
apply(tp.widenTermRefExpr)
165167
case tp =>
166168
mapOver(tp)
169+
170+
private def isLocalPath(tp: Type): Boolean = tp match
171+
case NoPrefix => true
172+
case tp: TermRef if !tp.symbol.is(Package) => isLocalPath(tp.prefix)
173+
case tp => false
167174
}
168175

169176
/** Remove references to local types that will not be defined in this quote */

compiler/src/dotty/tools/dotc/transform/Splicing.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ class Splicing extends MacroTransform:
231231
// Dealias references to captured types
232232
TypeTree(tree.tpe.dealias)
233233
else super.transform(tree)
234-
case tree: TypeTree =>
234+
case _: TypeTree | _: SingletonTypeTree =>
235235
if containsCapturedType(tree.tpe) && level >= 1 then getTagRefFor(tree)
236236
else tree
237237
case tree @ Assign(lhs: RefTree, rhs) =>
@@ -360,9 +360,8 @@ class Splicing extends MacroTransform:
360360
)
361361

362362
private def capturedType(tree: Tree)(using Context): Symbol =
363-
val tpe = tree.tpe.widenTermRefExpr
364363
val bindingSym = refBindingMap
365-
.getOrElseUpdate(tree.symbol, (TypeTree(tree.tpe), newQuotedTypeClassBinding(tpe)))._2
364+
.getOrElseUpdate(tree.symbol, (TypeTree(tree.tpe), newQuotedTypeClassBinding(tree.tpe)))._2
366365
bindingSym
367366

368367
private def capturedPartTypes(tpt: Tree)(using Context): Tree =

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ object TreeChecker {
658658
// Check that we only add the captured type `T` instead of a more complex type like `List[T]`.
659659
// If we have `F[T]` with captured `F` and `T`, we should list `F` and `T` separately in the args.
660660
for arg <- args do
661-
assert(arg.isTerm || arg.tpe.isInstanceOf[TypeRef], "Expected TypeRef in Hole type args but got: " + arg.tpe)
661+
assert(arg.isTerm || arg.tpe.isInstanceOf[TypeRef] || arg.tpe.isInstanceOf[TermRef], "Expected TypeRef or TermRef in Hole type args but got: " + arg.tpe)
662662

663663
// Check result type of the hole
664664
if isTermHole then assert(tpt.typeOpt <:< pt)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import scala.quoted.*
2+
3+
object Macro:
4+
inline def generateCode: Unit = ${ testThisPaths }
5+
6+
def testThisPaths(using Quotes): Expr[Unit] =
7+
'{
8+
trait E:
9+
type V
10+
val f: F
11+
${
12+
val expr = '{
13+
val _: Any = this
14+
val _: Any = f
15+
val _: this.type = ???
16+
val _: V = ???
17+
val _: this.V = ???
18+
val _: this.f.V = ???
19+
val _: this.type = ???
20+
val _: this.f.type = ???
21+
}
22+
expr
23+
}
24+
trait F:
25+
type V
26+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
@main def test = Macro.generateCode

0 commit comments

Comments
 (0)