Skip to content

Commit 7488b53

Browse files
committed
Fix scala#4230: Handle import in the REPL correctly
We used to collect user defined imports from the parsed tree and insert them as untyped trees at the top of the REPL wrapper object. This caused members shadowing issues. We now introduce a phase in the REPL compiler that collects imports after type checking and store them as typed tree. We can then create a context with its imports set in the correct order and use it to compile future expressions.
1 parent 6ee3c62 commit 7488b53

File tree

6 files changed

+117
-52
lines changed

6 files changed

+117
-52
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ class Compiler {
111111
List(new Flatten, // Lift all inner classes to package scope
112112
new RenameLifted, // Renames lifted classes to local numbering scheme
113113
new TransformWildcards, // Replace wildcards with default values
114-
new MoveStatics, // Move static methods to companion classes
114+
new MoveStatics, // Move static methods from companion to the class itself
115115
new ExpandPrivate, // Widen private definitions accessed from nested classes
116116
new RestoreScopes, // Repair scopes rendered invalid by moving definitions in prior phases of the group
117117
new SelectStatic, // get rid of selects that would be compiled into GetStatic
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package dotty.tools.repl
2+
3+
import dotty.tools.dotc.ast.Trees._
4+
import dotty.tools.dotc.ast.tpd
5+
import dotty.tools.dotc.core.Contexts.Context
6+
import dotty.tools.dotc.core.Phases.Phase
7+
8+
/** A phase that collects user defined top level imports.
9+
*
10+
* These imports must be collected as typed trees and therefore
11+
* after Typer.
12+
*/
13+
class CollectTopLevelImports extends Phase {
14+
import tpd._
15+
16+
def phaseName = "collecttoplevelimports"
17+
18+
private[this] var myImports: List[Import] = _
19+
def imports = myImports
20+
21+
def run(implicit ctx: Context): Unit = {
22+
def topLevelImports(tree: Tree) = {
23+
val PackageDef(_, _ :: TypeDef(_, rhs: Template) :: Nil) = tree
24+
rhs.body.collect { case tree: Import => tree }
25+
}
26+
27+
val tree = ctx.compilationUnit.tpdTree
28+
myImports = topLevelImports(tree)
29+
}
30+
}

compiler/src/dotty/tools/repl/ReplCompiler.scala

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import dotty.tools.dotc.core.Phases.Phase
1414
import dotty.tools.dotc.core.StdNames._
1515
import dotty.tools.dotc.core.Symbols._
1616
import dotty.tools.dotc.reporting.diagnostic.messages
17+
import dotty.tools.dotc.transform.PostTyper
1718
import dotty.tools.dotc.typer.{FrontEnd, ImportInfo}
1819
import dotty.tools.dotc.util.Positions._
1920
import dotty.tools.dotc.util.SourceFile
@@ -26,39 +27,43 @@ import scala.collection.mutable
2627
/** This subclass of `Compiler` replaces the appropriate phases in order to
2728
* facilitate the REPL
2829
*
29-
* Specifically it replaces the front end with `REPLFrontEnd`, and adds a
30-
* custom subclass of `GenBCode`. The custom `GenBCode`, `REPLGenBCode`, works
31-
* in conjunction with a specialized class loader in order to load virtual
32-
* classfiles.
30+
* Specifically it replaces the front end with `REPLFrontEnd`.
3331
*/
3432
class ReplCompiler extends Compiler {
35-
override protected def frontendPhases: List[List[Phase]] =
36-
Phases.replace(classOf[FrontEnd], _ => new REPLFrontEnd :: Nil, super.frontendPhases)
3733

38-
def newRun(initCtx: Context, objectIndex: Int) = new Run(this, initCtx) {
39-
override protected[this] def rootContext(implicit ctx: Context) =
40-
addMagicImports(super.rootContext)
34+
override protected def frontendPhases: List[List[Phase]] = List(
35+
List(new REPLFrontEnd),
36+
List(new CollectTopLevelImports),
37+
List(new PostTyper)
38+
)
4139

42-
private def addMagicImports(initCtx: Context): Context = {
43-
def addImport(path: TermName)(implicit ctx: Context) = {
44-
val importInfo = ImportInfo.rootImport { () =>
45-
ctx.requiredModuleRef(path)
46-
}
47-
ctx.fresh.setNewScope.setImportInfo(importInfo)
48-
}
40+
def newRunContext(initCtx: Context, state: State): Context = {
41+
def addUserDefinedImport(imp: tpd.Import)(implicit ctx: Context) =
42+
ctx.importContext(imp, imp.symbol)
4943

50-
(1 to objectIndex)
51-
.foldLeft(initCtx) { (ictx, i) =>
52-
addImport(nme.EMPTY_PACKAGE ++ "." ++ objectNames(i))(ictx)
53-
}
44+
def importModule(path: TermName)(implicit ctx: Context) = {
45+
val importInfo = ImportInfo.rootImport(() =>
46+
ctx.requiredModuleRef(path))
47+
ctx.fresh.setNewScope.setImportInfo(importInfo)
48+
}
49+
50+
val run = newRun(initCtx.fresh.setReporter(newStoreReporter))
51+
(1 to state.objectIndex).foldLeft(run.runContext) { (ctx, i) =>
52+
// we first import the wrapper object i
53+
val path = nme.EMPTY_PACKAGE ++ "." ++ objectNames(i)
54+
val ctx0 = importModule(path)(ctx)
55+
// then its user defined imports
56+
val imports = state.imports.getOrElse(i, Nil)
57+
if (imports.isEmpty) ctx0
58+
else imports.foldLeft(ctx0.fresh.setNewScope)((ctx, imp) =>
59+
addUserDefinedImport(imp)(ctx))
5460
}
5561
}
5662

5763
private[this] val objectNames = mutable.Map.empty[Int, TermName]
5864
private def objectName(state: State) =
59-
objectNames.getOrElseUpdate(state.objectIndex, {
60-
(str.REPL_SESSION_LINE + state.objectIndex).toTermName
61-
})
65+
objectNames.getOrElseUpdate(state.objectIndex,
66+
(str.REPL_SESSION_LINE + state.objectIndex).toTermName)
6267

6368
private case class Definitions(stats: List[untpd.Tree], state: State)
6469

@@ -86,7 +91,7 @@ class ReplCompiler extends Compiler {
8691
}
8792

8893
Definitions(
89-
state.imports ++ defs,
94+
defs,
9095
state.copy(
9196
objectIndex = state.objectIndex + (if (defs.isEmpty) 0 else 1),
9297
valIndex = valIdx
@@ -130,7 +135,7 @@ class ReplCompiler extends Compiler {
130135

131136
private def runCompilationUnit(unit: CompilationUnit, state: State): Result[(CompilationUnit, State)] = {
132137
val ctx = state.context
133-
ctx.run.compileUnits(unit :: Nil)
138+
ctx.run.compileUnits(unit :: Nil, ctx)
134139

135140
if (!ctx.reporter.hasErrors) (unit, state).result
136141
else ctx.reporter.removeBufferedMessages(ctx).errors
@@ -158,19 +163,18 @@ class ReplCompiler extends Compiler {
158163
def docOf(expr: String)(implicit state: State): Result[String] = {
159164
implicit val ctx: Context = state.context
160165

161-
/**
162-
* Extract the "selected" symbol from `tree`.
166+
/** Extract the "selected" symbol from `tree`.
163167
*
164-
* Because the REPL typechecks an expression, special syntax is needed to get the documentation
165-
* of certain symbols:
168+
* Because the REPL typechecks an expression, special syntax is needed to get the documentation
169+
* of certain symbols:
166170
*
167-
* - To select the documentation of classes, the user needs to pass a call to the class' constructor
168-
* (e.g. `new Foo` to select `class Foo`)
169-
* - When methods are overloaded, the user needs to enter a lambda to specify which functions he wants
170-
* (e.g. `foo(_: Int)` to select `def foo(x: Int)` instead of `def foo(x: String)`
171+
* - To select the documentation of classes, the user needs to pass a call to the class' constructor
172+
* (e.g. `new Foo` to select `class Foo`)
173+
* - When methods are overloaded, the user needs to enter a lambda to specify which functions he wants
174+
* (e.g. `foo(_: Int)` to select `def foo(x: Int)` instead of `def foo(x: String)`
171175
*
172-
* This function returns the right symbol for the received expression, and all the symbols that are
173-
* overridden.
176+
* This function returns the right symbol for the received expression, and all the symbols that are
177+
* overridden.
174178
*/
175179
def extractSymbols(tree: tpd.Tree): Iterator[Symbol] = {
176180
val sym = tree match {
@@ -210,7 +214,7 @@ class ReplCompiler extends Compiler {
210214
import untpd._
211215

212216
val valdef = ValDef("expr".toTermName, TypeTree(), Block(trees, unitLiteral))
213-
val tmpl = Template(emptyConstructor, Nil, EmptyValDef, state.imports :+ valdef)
217+
val tmpl = Template(emptyConstructor, Nil, EmptyValDef, List(valdef))
214218
val wrapper = TypeDef("$wrapper".toTypeName, tmpl)
215219
.withMods(Modifiers(Final))
216220
.withPos(Position(0, expr.length))
@@ -261,9 +265,8 @@ class ReplCompiler extends Compiler {
261265

262266
if (errorsAllowed || !ctx.reporter.hasErrors)
263267
unwrapped(unit.tpdTree, src)
264-
else {
268+
else
265269
ctx.reporter.removeBufferedMessages.errors
266-
}
267270
}
268271
}
269272
}

compiler/src/dotty/tools/repl/ReplDriver.scala

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ import scala.collection.JavaConverters._
4343
*
4444
* @param objectIndex the index of the next wrapper
4545
* @param valIndex the index of next value binding for free expressions
46-
* @param imports the list of user defined imports
46+
* @param imports a map from object index to the list of user defined imports
4747
* @param context the latest compiler context
4848
*/
4949
case class State(objectIndex: Int,
5050
valIndex: Int,
51-
imports: List[untpd.Import],
51+
imports: Map[Int, List[tpd.Import]],
5252
context: Context)
5353

5454
/** Main REPL instance, orchestrating input, compilation and presentation */
@@ -63,14 +63,14 @@ class ReplDriver(settings: Array[String],
6363

6464
/** Create a fresh and initialized context with IDE mode enabled */
6565
private[this] def initialCtx = {
66-
val rootCtx = initCtx.fresh.addMode(Mode.ReadPositions).addMode(Mode.Interactive).addMode(Mode.ReadComments)
66+
val rootCtx = initCtx.fresh.addMode(Mode.ReadPositions | Mode.Interactive | Mode.ReadComments)
6767
val ictx = setup(settings, rootCtx)._2
6868
ictx.base.initialize()(ictx)
6969
ictx
7070
}
7171

7272
/** the initial, empty state of the REPL session */
73-
protected[this] def initState = State(0, 0, Nil, rootCtx)
73+
protected[this] def initState = State(0, 0, Map.empty, rootCtx)
7474

7575
/** Reset state of repl to the initial state
7676
*
@@ -140,8 +140,8 @@ class ReplDriver(settings: Array[String],
140140
Console.withOut(out) { Console.withErr(out) { op } }
141141

142142
private def newRun(state: State) = {
143-
val run = compiler.newRun(rootCtx.fresh.setReporter(newStoreReporter), state.objectIndex)
144-
state.copy(context = run.runContext)
143+
val run = compiler.newRunContext(rootCtx, state)
144+
state.copy(context = run)
145145
}
146146

147147
/** Extract possible completions at the index of `cursor` in `expr` */
@@ -173,9 +173,6 @@ class ReplDriver(settings: Array[String],
173173
.getOrElse(Nil)
174174
}
175175

176-
private def extractImports(trees: List[untpd.Tree]): List[untpd.Import] =
177-
trees.collect { case imp: untpd.Import => imp }
178-
179176
private def interpret(res: ParseResult)(implicit state: State): State = {
180177
val newState = res match {
181178
case parsed: Parsed if parsed.trees.nonEmpty =>
@@ -205,6 +202,9 @@ class ReplDriver(settings: Array[String],
205202
case _ => nme.NO_NAME
206203
}
207204

205+
def extractTopLevelImports(ctx: Context): List[tpd.Import] =
206+
ctx.phases.collectFirst { case phase: CollectTopLevelImports => phase.imports }.get
207+
208208
implicit val state = newRun(istate)
209209
compiler
210210
.compile(parsed)
@@ -213,8 +213,11 @@ class ReplDriver(settings: Array[String],
213213
{
214214
case (unit: CompilationUnit, newState: State) =>
215215
val newestWrapper = extractNewestWrapper(unit.untpdTree)
216-
val newImports = newState.imports ++ extractImports(parsed.trees)
217-
val newStateWithImports = newState.copy(imports = newImports)
216+
val newImports = extractTopLevelImports(newState.context)
217+
var allImports = newState.imports
218+
if (newImports.nonEmpty)
219+
allImports += (newState.objectIndex -> newImports)
220+
val newStateWithImports = newState.copy(imports = allImports)
218221

219222
val warnings = newState.context.reporter.removeBufferedMessages(newState.context)
220223
displayErrors(warnings)(newState) // display warnings
@@ -311,7 +314,10 @@ class ReplDriver(settings: Array[String],
311314
initState
312315

313316
case Imports =>
314-
state.imports.foreach(i => out.println(SyntaxHighlighting(i.show(state.context))))
317+
for {
318+
objectIndex <- 1 to state.objectIndex
319+
imp <- state.imports.getOrElse(objectIndex, Nil)
320+
} out.println(imp.show(state.context))
315321
state
316322

317323
case Load(path) =>

compiler/src/dotty/tools/repl/ReplFrontEnd.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import dotc.core.Contexts.Context
1212
* compiler pipeline.
1313
*/
1414
private[repl] class REPLFrontEnd extends FrontEnd {
15-
override def phaseName = "frontend"
1615

1716
override def isRunnable(implicit ctx: Context) = true
1817

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
scala> object A { def f = 1 }
2+
// defined object A
3+
4+
scala> object B { def f = 2 }
5+
// defined object B
6+
7+
scala> import A._
8+
9+
scala> val x0 = f
10+
val x0: Int = 1
11+
12+
scala> import B._
13+
14+
scala> val x1 = f
15+
val x1: Int = 2
16+
17+
scala> def f = 3
18+
def f: Int
19+
20+
scala> val x2 = f
21+
val x2: Int = 3
22+
23+
scala> def f = 4; import A._
24+
def f: Int
25+
26+
scala> val x3 = f
27+
val x3: Int = 1

0 commit comments

Comments
 (0)