Skip to content

Commit 901cdf9

Browse files
committed
Improve handling of owners in reflection API
* Add `changeOwner` * Add `changeNonLocalOwners` * Add symbol to `Lambda.unapply` * Ycheck the owners durring macro expansion * Fix scala#10151 * Fix scala#10211
1 parent cd15c99 commit 901cdf9

File tree

12 files changed

+80
-19
lines changed

12 files changed

+80
-19
lines changed

compiler/src/scala/quoted/internal/impl/Matcher.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ object Matcher {
211211
}
212212
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
213213
val resType = pattern.tpe
214-
val res = Lambda(MethodType(names)(_ => argTypes, _ => resType), bodyFn)
214+
val res = Lambda(MethodType(names)(_ => argTypes, _ => resType), x => bodyFn(x).changeNonLocalOwners(Symbol.currentOwner))
215215
matched(res.asExpr)
216216

217217
//

compiler/src/scala/quoted/internal/impl/QuoteContextImpl.scala

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ object QuoteContextImpl {
4545

4646
class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickler, QuoteMatching:
4747

48+
private val yCheck: Boolean =
49+
ctx.settings.Ycheck.value(using ctx).exists(x => x == "all" || x == "macros")
50+
4851
extension [T](self: scala.quoted.Expr[T]):
4952
def show: String =
5053
reflect.TreeMethodsImpl.show(reflect.Term.of(self))
@@ -118,6 +121,13 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
118121
QuoteContextImpl.this.asExprOf[T](self.asExpr)(using tp)
119122
end extension
120123

124+
extension [ThisTree <: Tree](self: ThisTree):
125+
def changeOwner(from: Symbol, to: Symbol): ThisTree =
126+
tpd.TreeOps(self).changeOwner(from, to)
127+
def changeNonLocalOwners(newOwner: Symbol): ThisTree =
128+
tpd.TreeOps(self).changeNonLocalOwners(newOwner).asInstanceOf[ThisTree]
129+
end extension
130+
121131
end TreeMethodsImpl
122132

123133
type PackageClause = tpd.PackageDef
@@ -238,9 +248,9 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
238248

239249
object DefDef extends DefDefModule:
240250
def apply(symbol: Symbol, rhsFn: List[TypeRepr] => List[List[Term]] => Option[Term]): DefDef =
241-
withDefaultPos(tpd.polyDefDef(symbol.asTerm, tparams => vparamss => rhsFn(tparams)(vparamss).getOrElse(tpd.EmptyTree)))
251+
withDefaultPos(tpd.polyDefDef(symbol.asTerm, tparams => vparamss => yCheckedOwners(rhsFn(tparams)(vparamss), symbol).getOrElse(tpd.EmptyTree)))
242252
def copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term]): DefDef =
243-
tpd.cpy.DefDef(original)(name.toTermName, typeParams, paramss, tpt, rhs.getOrElse(tpd.EmptyTree))
253+
tpd.cpy.DefDef(original)(name.toTermName, typeParams, paramss, tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
244254
def unapply(ddef: DefDef): Option[(String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term])] =
245255
Some((ddef.name.toString, ddef.typeParams, ddef.paramss, ddef.tpt, optional(ddef.rhs)))
246256
end DefDef
@@ -264,9 +274,9 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
264274

265275
object ValDef extends ValDefModule:
266276
def apply(symbol: Symbol, rhs: Option[Term]): ValDef =
267-
tpd.ValDef(symbol.asTerm, rhs.getOrElse(tpd.EmptyTree))
277+
tpd.ValDef(symbol.asTerm, yCheckedOwners(rhs, symbol).getOrElse(tpd.EmptyTree))
268278
def copy(original: Tree)(name: String, tpt: TypeTree, rhs: Option[Term]): ValDef =
269-
tpd.cpy.ValDef(original)(name.toTermName, tpt, rhs.getOrElse(tpd.EmptyTree))
279+
tpd.cpy.ValDef(original)(name.toTermName, tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
270280
def unapply(vdef: ValDef): Option[(String, TypeTree, Option[Term])] =
271281
Some((vdef.name.toString, vdef.tpt, optional(vdef.rhs)))
272282

@@ -729,12 +739,12 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
729739
object Lambda extends LambdaModule:
730740
def apply(tpe: MethodType, rhsFn: List[Tree] => Tree): Block =
731741
val meth = dotc.core.Symbols.newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, tpe)
732-
tpd.Closure(meth, tss => changeOwnerOfTree(rhsFn(tss.head), meth))
742+
tpd.Closure(meth, tss => tpd.TreeOps(yCheckedOwners(rhsFn(tss.head), ctx.owner)).changeOwner(ctx.owner, meth))
733743

734-
def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
744+
def unapply(tree: Block): Option[(Symbol, List[ValDef], Term)] = tree match {
735745
case Block((ddef @ DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
736746
if ddef.symbol == meth.symbol =>
737-
Some((params, body))
747+
Some((meth.symbol, params, body))
738748
case _ => None
739749
}
740750
end Lambda
@@ -2541,6 +2551,44 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
25412551
private def withDefaultPos[T <: Tree](fn: Context ?=> T): T =
25422552
fn(using ctx.withSource(Position.ofMacroExpansion.source)).withSpan(Position.ofMacroExpansion.span)
25432553

2554+
private def yCheckedOwners(tree: Option[Tree], owner: Symbol): tree.type =
2555+
if yCheck then
2556+
tree match
2557+
case Some(tree) =>
2558+
yCheckOwners(tree, owner)
2559+
case _ =>
2560+
tree
2561+
2562+
private def yCheckedOwners(tree: Tree, owner: Symbol): tree.type =
2563+
if yCheck then
2564+
yCheckOwners(tree, owner)
2565+
tree
2566+
2567+
private def yCheckOwners(tree: Tree, owner: Symbol): Unit =
2568+
new tpd.TreeTraverser {
2569+
def traverse(t: Tree)(using Context): Unit =
2570+
t match
2571+
case t: tpd.DefTree =>
2572+
val defOwner = t.symbol.owner
2573+
assert(defOwner == owner,
2574+
s"""Tree had an unexpected owner for ${t.symbol}
2575+
|Expected: $owner (${owner.fullName})
2576+
|But was: $defOwner (${defOwner.fullName})
2577+
|
2578+
|
2579+
|The code of the definition of ${t.symbol} is
2580+
|${TreeMethods.show(t)}
2581+
|
2582+
|which was found in the code
2583+
|${TreeMethods.show(tree)}
2584+
|
2585+
|which has the AST representation
2586+
|${TreeMethods.showExtractors(tree)}
2587+
|
2588+
|""".stripMargin) // TODO add a good error message
2589+
case _ => traverseChildren(t)
2590+
}.traverse(tree)
2591+
25442592
end reflect
25452593

25462594
def unpickleExpr[T](pickled: String | List[String], typeHole: (Int, Seq[Any]) => scala.quoted.Type[?], termHole: (Int, Seq[Any], scala.quoted.QuoteContext) => scala.quoted.Expr[?]): scala.quoted.Expr[T] =

compiler/src/scala/quoted/internal/impl/printers/SourceCode.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ object SourceCode {
456456
this += " = "
457457
printTree(rhs)
458458

459-
case tree @ Lambda(params, body) => // must come before `Block`
459+
case tree @ Lambda(_, params, body) => // must come before `Block`
460460
inParens {
461461
printArgsDefs(params)
462462
this += (if tree.tpe.isContextFunctionType then " ?=> " else " => ")
@@ -543,7 +543,7 @@ object SourceCode {
543543
private def flatBlock(stats: List[Statement], expr: Term): (List[Statement], Term) = {
544544
val flatStats = List.newBuilder[Statement]
545545
def extractFlatStats(stat: Statement): Unit = stat match {
546-
case Lambda(_, _) => // must come before `Block`
546+
case Lambda(_, _, _) => // must come before `Block`
547547
flatStats += stat
548548
case Block(stats1, expr1) =>
549549
val it = stats1.iterator
@@ -559,7 +559,7 @@ object SourceCode {
559559
case stat => flatStats += stat
560560
}
561561
def extractFlatExpr(term: Term): Term = term match {
562-
case Lambda(_, _) => // must come before `Block`
562+
case Lambda(_, _, _) => // must come before `Block`
563563
term
564564
case Block(stats1, expr1) =>
565565
val it = stats1.iterator
@@ -601,7 +601,7 @@ object SourceCode {
601601
def printSeparator(next: Tree): Unit = {
602602
// Avoid accidental application of opening `{` on next line with a double break
603603
def rec(next: Tree): Unit = next match {
604-
case Lambda(_, _) => this += lineBreak()
604+
case Lambda(_, _, _) => this += lineBreak()
605605
case Block(stats, _) if stats.nonEmpty => this += doubleLineBreak()
606606
case Inlined(_, bindings, _) if bindings.nonEmpty => this += doubleLineBreak()
607607
case Select(qual, _) => rec(qual)

library/src/scala/quoted/QuoteContext.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ trait QuoteContext { self: internal.QuoteUnpickler & internal.QuoteMatching =>
182182
// CONTEXTS //
183183
//////////////
184184

185-
/** Compilation context */
185+
/** Context containing information on the current owner */
186186
type Context <: AnyRef
187187

188188
/** Context of the macro expansion */
@@ -238,6 +238,14 @@ trait QuoteContext { self: internal.QuoteUnpickler & internal.QuoteMatching =>
238238
/** Convert this tree to an `quoted.Expr[T]` if the tree is a valid expression or throws */
239239
extension [T](self: Tree)
240240
def asExprOf(using scala.quoted.Type[T]): scala.quoted.Expr[T]
241+
242+
extension [ThisTree <: Tree](self: ThisTree):
243+
/** TODO */
244+
def changeOwner(from: Symbol, to: Symbol): ThisTree
245+
/** TODO */
246+
def changeNonLocalOwners(newOwner: Symbol): ThisTree
247+
end extension
248+
241249
}
242250

243251
/** Tree representing a pacakage clause in the source code */
@@ -962,7 +970,12 @@ trait QuoteContext { self: internal.QuoteUnpickler & internal.QuoteMatching =>
962970
val Lambda: LambdaModule
963971

964972
trait LambdaModule { this: Lambda.type =>
965-
def unapply(tree: Block): Option[(List[ValDef], Term)]
973+
def unapply(tree: Block): Option[(Symbol, List[ValDef], Term)]
974+
975+
/* Generates a lambda with the given method type.
976+
*
977+
* Any definition in `rhs` is expected to be owned by the current owner in context.
978+
*/
966979
def apply(tpe: MethodType, rhsFn: List[Tree] => Tree): Block
967980
}
968981

tests/neg-staging/i5941/macro_1.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ object Lens {
2525
None, Nil,
2626
Block(
2727
DefDef(_, Nil, (param :: Nil) :: Nil, _, Some(Select(o, field))) :: Nil,
28-
Lambda(meth, _)
28+
Lambda(_, meth, _)
2929
)
3030
) if o.symbol == param.symbol =>
3131
'{

tests/pos-macros/i10151/Macro_1.scala

Whitespace-only changes.

tests/pos-macros/i10151/Test_2.scala

Whitespace-only changes.

tests/pos-macros/i10211/Macro_1.scala

Whitespace-only changes.

tests/pos-macros/i10211/Test_2.scala

Whitespace-only changes.

tests/pos-macros/i9894/Macro_1.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ object X:
4242

4343
def shiftLambda(term:Term): Term =
4444
term match
45-
case lt@Lambda(params, body) =>
45+
case lt@Lambda(meth, params, body) =>
4646
val paramTypes = params.map(_.tpt.tpe)
4747
val paramNames = params.map(_.name)
4848
val mt = MethodType(paramNames)(_ => paramTypes, _ => TypeRepr.of[CB].appliedTo(body.tpe.widen) )
49-
val r = Lambda(mt, args => changeArgs(params,args,transform(body)) )
49+
val r = Lambda(mt, args => changeArgs(params,args,transform(body).changeOwner(meth, Symbol.currentOwner)) )
5050
r
5151
case _ =>
5252
throw RuntimeException("lambda expected")

tests/run-macros/i5941/macro_1.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ object Lens {
4040

4141
object Function {
4242
def unapply(t: Term): Option[(List[ValDef], Term)] = t match {
43-
case Inlined(None, Nil, Lambda(params, body)) => Some((params, body))
43+
case Inlined(None, Nil, Lambda(_, params, body)) => Some((params, body))
4444
case _ => None
4545
}
4646
}

tests/run-macros/reflect-lambda/assert_1.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ object lib {
99
import util._
1010

1111
Term.of(cond).underlyingArgument match {
12-
case t @ Apply(Select(lhs, op), Lambda(param :: Nil, Apply(Select(a, "=="), b :: Nil)) :: Nil)
12+
case t @ Apply(Select(lhs, op), Lambda(_, param :: Nil, Apply(Select(a, "=="), b :: Nil)) :: Nil)
1313
if a.symbol == param.symbol || b.symbol == param.symbol =>
1414
'{ scala.Predef.assert($cond) }
1515
}

0 commit comments

Comments
 (0)