Skip to content

Use transformStats to handle imports #11591

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TreeMapWithImplicits extends tpd.TreeMap {
* - be tail-recursive where possible
* - don't re-allocate trees where nothing has changed
*/
def transformStats(stats: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = {
override def transformStats(stats: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = {

@tailrec def traverse(curStats: List[Tree])(using Context): List[Tree] = {

Expand Down Expand Up @@ -88,8 +88,14 @@ class TreeMapWithImplicits extends tpd.TreeMap {
def localCtx =
if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx
try tree match {
case tree: Block =>
super.transform(tree)(using nestedScopeCtx(tree.stats))
case Block(stats, expr) =>
inContext(nestedScopeCtx(stats)) {
if stats.exists(_.isInstanceOf[Import]) then
// need to transform stats and expr together to account for import visibility
val stats1 = transformStats(stats :+ expr, ctx.owner)
cpy.Block(tree)(stats1.init, stats1.last)
else super.transform(tree)
}
case tree: DefDef =>
inContext(localCtx) {
cpy.DefDef(tree)(
Expand All @@ -100,8 +106,10 @@ class TreeMapWithImplicits extends tpd.TreeMap {
}
case EmptyValDef =>
tree
case _: PackageDef | _: MemberDef =>
case _: MemberDef =>
super.transform(tree)(using localCtx)
case _: PackageDef =>
super.transform(tree)(using ctx.withOwner(tree.symbol.moduleClass))
case impl @ Template(constr, parents, self, _) =>
cpy.Template(tree)(
transformSub(constr),
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class TreeTypeMap(
}
}

override def transformStats(trees: List[tpd.Tree])(using Context): List[Tree] =
override def transformStats(trees: List[tpd.Tree], exprOwner: Symbol)(using Context): List[Tree] =
transformDefs(trees)._2

def transformDefs[TT <: tpd.Tree](trees: List[TT])(using Context): (TreeTypeMap, List[TT]) = {
Expand Down
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ object Trees {
case Assign(lhs, rhs) =>
cpy.Assign(tree)(transform(lhs), transform(rhs))
case Block(stats, expr) =>
cpy.Block(tree)(transformStats(stats), transform(expr))
cpy.Block(tree)(transformStats(stats, ctx.owner), transform(expr))
case If(cond, thenp, elsep) =>
cpy.If(tree)(transform(cond), transform(thenp), transform(elsep))
case Closure(env, meth, tpt) =>
Expand Down Expand Up @@ -1398,13 +1398,13 @@ object Trees {
cpy.TypeDef(tree)(name, transform(rhs))
}
case tree @ Template(constr, parents, self, _) if tree.derived.isEmpty =>
cpy.Template(tree)(transformSub(constr), transform(tree.parents), Nil, transformSub(self), transformStats(tree.body))
cpy.Template(tree)(transformSub(constr), transform(tree.parents), Nil, transformSub(self), transformStats(tree.body, tree.symbol))
case Import(expr, selectors) =>
cpy.Import(tree)(transform(expr), selectors)
case Export(expr, selectors) =>
cpy.Export(tree)(transform(expr), selectors)
case PackageDef(pid, stats) =>
cpy.PackageDef(tree)(transformSub(pid), transformStats(stats)(using localCtx))
cpy.PackageDef(tree)(transformSub(pid), transformStats(stats, pid.symbol.moduleClass)(using localCtx))
case Annotated(arg, annot) =>
cpy.Annotated(tree)(transform(arg), transform(annot))
case Thicket(trees) =>
Expand All @@ -1416,7 +1416,7 @@ object Trees {
}
}

def transformStats(trees: List[Tree])(using Context): List[Tree] =
def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
transform(trees)
def transform(trees: List[Tree])(using Context): List[Tree] =
flatten(trees mapConserve (transform(_)))
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case ModuleDef(name, impl) =>
cpy.ModuleDef(tree)(name, transformSub(impl))
case tree: DerivingTemplate =>
cpy.Template(tree)(transformSub(tree.constr), transform(tree.parents), transform(tree.derived), transformSub(tree.self), transformStats(tree.body))
cpy.Template(tree)(transformSub(tree.constr), transform(tree.parents),
transform(tree.derived), transformSub(tree.self), transformStats(tree.body, tree.symbol))
case ParsedTry(expr, handler, finalizer) =>
cpy.ParsedTry(tree)(transform(expr), transform(handler), transform(finalizer))
case SymbolLit(str) =>
Expand Down
16 changes: 11 additions & 5 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -538,11 +538,17 @@ object Contexts {
case _ => new Typer
}

override def toString: String = {
def iinfo(using Context) = if (ctx.importInfo == null) "" else i"${ctx.importInfo.selectors}%, %"
"Context(\n" +
(outersIterator.map(ctx => s" owner = ${ctx.owner}, scope = ${ctx.scope}, import = ${iinfo(using ctx)}").mkString("\n"))
}
override def toString: String =
def iinfo(using Context) =
if (ctx.importInfo == null) "" else i"${ctx.importInfo.selectors}%, %"
def cinfo(using Context) =
val core = s" owner = ${ctx.owner}, scope = ${ctx.scope}, import = ${iinfo(using ctx)}"
if (ctx ne NoContext) && (ctx.implicits ne ctx.outer.implicits) then
s"$core, implicits = ${ctx.implicits}"
else
core
s"""Context(
|${outersIterator.map(ctx => cinfo(using ctx)).mkString("\n\n")})""".stripMargin

def settings: ScalaSettings = base.settings
def definitions: Definitions = base.definitions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ abstract class MacroTransform extends Phase {
ctx.fresh.setTree(tree).setOwner(owner)
}

def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = {
override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = {
def transformStat(stat: Tree): Tree = stat match {
case _: Import | _: DefTree => transform(stat)
case _ => transform(stat)(using ctx.exprContext(stat, exprOwner))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ abstract class TreeMapWithStages(@constructorOnly ictx: Context) extends TreeMap
/** The quotation level of the definition of the locally defined symbol */
protected def levelOf(sym: Symbol): Int = levelOfMap.getOrElse(sym, 0)

/** Localy defined symbols seen so far by `StagingTransformer.transform` */
/** Locally defined symbols seen so far by `StagingTransformer.transform` */
protected def localSymbols: List[Symbol] = enteredSyms

/** If we are inside a quote or a splice */
Expand Down Expand Up @@ -74,7 +74,7 @@ abstract class TreeMapWithStages(@constructorOnly ictx: Context) extends TreeMap
/** Transform the expression splice `splice` which contains the spliced `body`. */
protected def transformSplice(body: Tree, splice: Apply)(using Context): Tree

/** Transform the typee splice `splice` which contains the spliced `body`. */
/** Transform the type splice `splice` which contains the spliced `body`. */
protected def transformSpliceType(body: Tree, splice: Select)(using Context): Tree

override def transform(tree: Tree)(using Context): Tree =
Expand Down Expand Up @@ -109,7 +109,7 @@ abstract class TreeMapWithStages(@constructorOnly ictx: Context) extends TreeMap
try dropEmptyBlocks(quotedTree) match {
case Spliced(t) =>
// '{ $x } --> x
// and adapt the refinment of `Quotes { type tasty: ... } ?=> Expr[T]`
// and adapt the refinement of `Quotes { type reflect: ... } ?=> Expr[T]`
transform(t).asInstance(tree.tpe)
case _ => transformQuotation(quotedTree, tree)
}
Expand Down
9 changes: 9 additions & 0 deletions tests/pos-macros/i11479/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
trait Foo
given Foo: Foo with {}
inline def summonFoo(): Foo = scala.compiletime.summonInline[Foo]

package p:
trait Bar
given Bar: Bar with {}
inline def summonBar(): Bar = scala.compiletime.summonInline[Bar]

2 changes: 2 additions & 0 deletions tests/pos-macros/i11479/OtherTest_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
package p
def test3: Unit = summonBar()
8 changes: 8 additions & 0 deletions tests/pos-macros/i11479/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import p.{*, given}
def test: Unit =
summonFoo()
summonBar()




23 changes: 23 additions & 0 deletions tests/pos/i11538a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package a:

trait Printer[A]:
def print(a: A): Unit

given Printer[String] with
def print(s: String) = println(s)

package b:

import a.{given, *}

object test:
import scala.compiletime.{error, summonFrom}

inline def summonStringPrinter =
summonFrom {
case given Printer[String] => ()
case _ => error("Couldn't find a printer")
}

val summoned = summon[Printer[String]]
val summonedFrom = summonStringPrinter
9 changes: 9 additions & 0 deletions tests/pos/i11538b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package a:
type Foo
given foo: Foo = ???

import a.{Foo, given}
object test:
inline def summonInlineFoo = scala.compiletime.summonInline[Foo]
val summoned = summon[Foo]
val summonedInline = summonInlineFoo
12 changes: 12 additions & 0 deletions tests/pos/i11557.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
type MyEncoder

class MyContext:
given intEncoder: MyEncoder = ???

def doEncoding(ctx: MyContext): Unit =
import ctx.{*, given}
summon[MyEncoder]
summonInlineMyEncoder()

inline def summonInlineMyEncoder(): Unit =
compiletime.summonInline[MyEncoder]