Skip to content

Commit bd9926d

Browse files
Merge pull request #11591 from dotty-staging/fix-11538
Use transformStats to handle imports
2 parents 3bb16f8 + d5fa3f2 commit bd9926d

13 files changed

+97
-19
lines changed

compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class TreeMapWithImplicits extends tpd.TreeMap {
2626
* - be tail-recursive where possible
2727
* - don't re-allocate trees where nothing has changed
2828
*/
29-
def transformStats(stats: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = {
29+
override def transformStats(stats: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = {
3030

3131
@tailrec def traverse(curStats: List[Tree])(using Context): List[Tree] = {
3232

@@ -88,8 +88,14 @@ class TreeMapWithImplicits extends tpd.TreeMap {
8888
def localCtx =
8989
if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx
9090
try tree match {
91-
case tree: Block =>
92-
super.transform(tree)(using nestedScopeCtx(tree.stats))
91+
case Block(stats, expr) =>
92+
inContext(nestedScopeCtx(stats)) {
93+
if stats.exists(_.isInstanceOf[Import]) then
94+
// need to transform stats and expr together to account for import visibility
95+
val stats1 = transformStats(stats :+ expr, ctx.owner)
96+
cpy.Block(tree)(stats1.init, stats1.last)
97+
else super.transform(tree)
98+
}
9399
case tree: DefDef =>
94100
inContext(localCtx) {
95101
cpy.DefDef(tree)(
@@ -100,8 +106,10 @@ class TreeMapWithImplicits extends tpd.TreeMap {
100106
}
101107
case EmptyValDef =>
102108
tree
103-
case _: PackageDef | _: MemberDef =>
109+
case _: MemberDef =>
104110
super.transform(tree)(using localCtx)
111+
case _: PackageDef =>
112+
super.transform(tree)(using ctx.withOwner(tree.symbol.moduleClass))
105113
case impl @ Template(constr, parents, self, _) =>
106114
cpy.Template(tree)(
107115
transformSub(constr),

compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class TreeTypeMap(
130130
}
131131
}
132132

133-
override def transformStats(trees: List[tpd.Tree])(using Context): List[Tree] =
133+
override def transformStats(trees: List[tpd.Tree], exprOwner: Symbol)(using Context): List[Tree] =
134134
transformDefs(trees)._2
135135

136136
def transformDefs[TT <: tpd.Tree](trees: List[TT])(using Context): (TreeTypeMap, List[TT]) = {

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,7 +1332,7 @@ object Trees {
13321332
case Assign(lhs, rhs) =>
13331333
cpy.Assign(tree)(transform(lhs), transform(rhs))
13341334
case Block(stats, expr) =>
1335-
cpy.Block(tree)(transformStats(stats), transform(expr))
1335+
cpy.Block(tree)(transformStats(stats, ctx.owner), transform(expr))
13361336
case If(cond, thenp, elsep) =>
13371337
cpy.If(tree)(transform(cond), transform(thenp), transform(elsep))
13381338
case Closure(env, meth, tpt) =>
@@ -1398,13 +1398,13 @@ object Trees {
13981398
cpy.TypeDef(tree)(name, transform(rhs))
13991399
}
14001400
case tree @ Template(constr, parents, self, _) if tree.derived.isEmpty =>
1401-
cpy.Template(tree)(transformSub(constr), transform(tree.parents), Nil, transformSub(self), transformStats(tree.body))
1401+
cpy.Template(tree)(transformSub(constr), transform(tree.parents), Nil, transformSub(self), transformStats(tree.body, tree.symbol))
14021402
case Import(expr, selectors) =>
14031403
cpy.Import(tree)(transform(expr), selectors)
14041404
case Export(expr, selectors) =>
14051405
cpy.Export(tree)(transform(expr), selectors)
14061406
case PackageDef(pid, stats) =>
1407-
cpy.PackageDef(tree)(transformSub(pid), transformStats(stats)(using localCtx))
1407+
cpy.PackageDef(tree)(transformSub(pid), transformStats(stats, pid.symbol.moduleClass)(using localCtx))
14081408
case Annotated(arg, annot) =>
14091409
cpy.Annotated(tree)(transform(arg), transform(annot))
14101410
case Thicket(trees) =>
@@ -1416,7 +1416,7 @@ object Trees {
14161416
}
14171417
}
14181418

1419-
def transformStats(trees: List[Tree])(using Context): List[Tree] =
1419+
def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
14201420
transform(trees)
14211421
def transform(trees: List[Tree])(using Context): List[Tree] =
14221422
flatten(trees mapConserve (transform(_)))

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
659659
case ModuleDef(name, impl) =>
660660
cpy.ModuleDef(tree)(name, transformSub(impl))
661661
case tree: DerivingTemplate =>
662-
cpy.Template(tree)(transformSub(tree.constr), transform(tree.parents), transform(tree.derived), transformSub(tree.self), transformStats(tree.body))
662+
cpy.Template(tree)(transformSub(tree.constr), transform(tree.parents),
663+
transform(tree.derived), transformSub(tree.self), transformStats(tree.body, tree.symbol))
663664
case ParsedTry(expr, handler, finalizer) =>
664665
cpy.ParsedTry(tree)(transform(expr), transform(handler), transform(finalizer))
665666
case SymbolLit(str) =>

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -538,11 +538,17 @@ object Contexts {
538538
case _ => new Typer
539539
}
540540

541-
override def toString: String = {
542-
def iinfo(using Context) = if (ctx.importInfo == null) "" else i"${ctx.importInfo.selectors}%, %"
543-
"Context(\n" +
544-
(outersIterator.map(ctx => s" owner = ${ctx.owner}, scope = ${ctx.scope}, import = ${iinfo(using ctx)}").mkString("\n"))
545-
}
541+
override def toString: String =
542+
def iinfo(using Context) =
543+
if (ctx.importInfo == null) "" else i"${ctx.importInfo.selectors}%, %"
544+
def cinfo(using Context) =
545+
val core = s" owner = ${ctx.owner}, scope = ${ctx.scope}, import = ${iinfo(using ctx)}"
546+
if (ctx ne NoContext) && (ctx.implicits ne ctx.outer.implicits) then
547+
s"$core, implicits = ${ctx.implicits}"
548+
else
549+
core
550+
s"""Context(
551+
|${outersIterator.map(ctx => cinfo(using ctx)).mkString("\n\n")})""".stripMargin
546552

547553
def settings: ScalaSettings = base.settings
548554
def definitions: Definitions = base.definitions

compiler/src/dotty/tools/dotc/transform/MacroTransform.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ abstract class MacroTransform extends Phase {
3737
ctx.fresh.setTree(tree).setOwner(owner)
3838
}
3939

40-
def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = {
40+
override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = {
4141
def transformStat(stat: Tree): Tree = stat match {
4242
case _: Import | _: DefTree => transform(stat)
4343
case _ => transform(stat)(using ctx.exprContext(stat, exprOwner))

compiler/src/dotty/tools/dotc/transform/TreeMapWithStages.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ abstract class TreeMapWithStages(@constructorOnly ictx: Context) extends TreeMap
4646
/** The quotation level of the definition of the locally defined symbol */
4747
protected def levelOf(sym: Symbol): Int = levelOfMap.getOrElse(sym, 0)
4848

49-
/** Localy defined symbols seen so far by `StagingTransformer.transform` */
49+
/** Locally defined symbols seen so far by `StagingTransformer.transform` */
5050
protected def localSymbols: List[Symbol] = enteredSyms
5151

5252
/** If we are inside a quote or a splice */
@@ -74,7 +74,7 @@ abstract class TreeMapWithStages(@constructorOnly ictx: Context) extends TreeMap
7474
/** Transform the expression splice `splice` which contains the spliced `body`. */
7575
protected def transformSplice(body: Tree, splice: Apply)(using Context): Tree
7676

77-
/** Transform the typee splice `splice` which contains the spliced `body`. */
77+
/** Transform the type splice `splice` which contains the spliced `body`. */
7878
protected def transformSpliceType(body: Tree, splice: Select)(using Context): Tree
7979

8080
override def transform(tree: Tree)(using Context): Tree =
@@ -109,7 +109,7 @@ abstract class TreeMapWithStages(@constructorOnly ictx: Context) extends TreeMap
109109
try dropEmptyBlocks(quotedTree) match {
110110
case Spliced(t) =>
111111
// '{ $x } --> x
112-
// and adapt the refinment of `Quotes { type tasty: ... } ?=> Expr[T]`
112+
// and adapt the refinement of `Quotes { type reflect: ... } ?=> Expr[T]`
113113
transform(t).asInstance(tree.tpe)
114114
case _ => transformQuotation(quotedTree, tree)
115115
}

tests/pos-macros/i11479/Macro_1.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
trait Foo
2+
given Foo: Foo with {}
3+
inline def summonFoo(): Foo = scala.compiletime.summonInline[Foo]
4+
5+
package p:
6+
trait Bar
7+
given Bar: Bar with {}
8+
inline def summonBar(): Bar = scala.compiletime.summonInline[Bar]
9+
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
package p
2+
def test3: Unit = summonBar()

tests/pos-macros/i11479/Test_2.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import p.{*, given}
2+
def test: Unit =
3+
summonFoo()
4+
summonBar()
5+
6+
7+
8+

tests/pos/i11538a.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package a:
2+
3+
trait Printer[A]:
4+
def print(a: A): Unit
5+
6+
given Printer[String] with
7+
def print(s: String) = println(s)
8+
9+
package b:
10+
11+
import a.{given, *}
12+
13+
object test:
14+
import scala.compiletime.{error, summonFrom}
15+
16+
inline def summonStringPrinter =
17+
summonFrom {
18+
case given Printer[String] => ()
19+
case _ => error("Couldn't find a printer")
20+
}
21+
22+
val summoned = summon[Printer[String]]
23+
val summonedFrom = summonStringPrinter

tests/pos/i11538b.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package a:
2+
type Foo
3+
given foo: Foo = ???
4+
5+
import a.{Foo, given}
6+
object test:
7+
inline def summonInlineFoo = scala.compiletime.summonInline[Foo]
8+
val summoned = summon[Foo]
9+
val summonedInline = summonInlineFoo

tests/pos/i11557.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
type MyEncoder
2+
3+
class MyContext:
4+
given intEncoder: MyEncoder = ???
5+
6+
def doEncoding(ctx: MyContext): Unit =
7+
import ctx.{*, given}
8+
summon[MyEncoder]
9+
summonInlineMyEncoder()
10+
11+
inline def summonInlineMyEncoder(): Unit =
12+
compiletime.summonInline[MyEncoder]

0 commit comments

Comments
 (0)