Skip to content

Commit fc2717c

Browse files
committed
Fix #4230: Handle import in the REPL correctly
We used to collect user defined imports from the parsed tree and insert them as untyped trees at the top of the REPL wrapper object. This caused members shadowing issues. We now introduce a phase in the REPL compiler that collects imports after type checking and store them as typed tree. We can then create a context with its imports set in the correct order and use it to compile future expressions. This also fixes #4978 by having user defined import in the run context. Auto-completions somehow ignored them when they were part of the untyped tree.
1 parent 63b2610 commit fc2717c

File tree

8 files changed

+144
-53
lines changed

8 files changed

+144
-53
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ class Compiler {
111111
List(new Flatten, // Lift all inner classes to package scope
112112
new RenameLifted, // Renames lifted classes to local numbering scheme
113113
new TransformWildcards, // Replace wildcards with default values
114-
new MoveStatics, // Move static methods to companion classes
114+
new MoveStatics, // Move static methods from companion to the class itself
115115
new ExpandPrivate, // Widen private definitions accessed from nested classes
116116
new RestoreScopes, // Repair scopes rendered invalid by moving definitions in prior phases of the group
117117
new SelectStatic, // get rid of selects that would be compiled into GetStatic

compiler/src/dotty/tools/dotc/Run.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
147147
compileUnits()(ctx)
148148
}
149149

150-
protected def compileUnits()(implicit ctx: Context) = Stats.maybeMonitored {
150+
private def compileUnits()(implicit ctx: Context) = Stats.maybeMonitored {
151151
if (!ctx.mode.is(Mode.Interactive)) // IDEs might have multi-threaded access, accesses are synchronized
152152
ctx.base.checkSingleThreaded()
153153

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package dotty.tools.repl
2+
3+
import dotty.tools.dotc.ast.Trees._
4+
import dotty.tools.dotc.ast.tpd
5+
import dotty.tools.dotc.core.Contexts.Context
6+
import dotty.tools.dotc.core.Phases.Phase
7+
8+
/** A phase that collects user defined top level imports.
9+
*
10+
* These imports must be collected as typed trees and therefore
11+
* after Typer.
12+
*/
13+
class CollectTopLevelImports extends Phase {
14+
import tpd._
15+
16+
def phaseName = "collectTopLevelImports"
17+
18+
private[this] var myImports: List[Import] = _
19+
def imports = myImports
20+
21+
def run(implicit ctx: Context): Unit = {
22+
def topLevelImports(tree: Tree) = {
23+
val PackageDef(_, _ :: TypeDef(_, rhs: Template) :: Nil) = tree
24+
rhs.body.collect { case tree: Import => tree }
25+
}
26+
27+
val tree = ctx.compilationUnit.tpdTree
28+
myImports = topLevelImports(tree)
29+
}
30+
}

compiler/src/dotty/tools/repl/ReplCompiler.scala

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import dotty.tools.dotc.core.Phases.Phase
1414
import dotty.tools.dotc.core.StdNames._
1515
import dotty.tools.dotc.core.Symbols._
1616
import dotty.tools.dotc.reporting.diagnostic.messages
17+
import dotty.tools.dotc.transform.PostTyper
1718
import dotty.tools.dotc.typer.{FrontEnd, ImportInfo}
1819
import dotty.tools.dotc.util.Positions._
1920
import dotty.tools.dotc.util.SourceFile
@@ -23,42 +24,52 @@ import dotty.tools.repl.results._
2324

2425
import scala.collection.mutable
2526

26-
/** This subclass of `Compiler` replaces the appropriate phases in order to
27-
* facilitate the REPL
27+
/** This subclass of `Compiler` is adapted for use in the REPL.
2828
*
29-
* Specifically it replaces the front end with `REPLFrontEnd`, and adds a
30-
* custom subclass of `GenBCode`. The custom `GenBCode`, `REPLGenBCode`, works
31-
* in conjunction with a specialized class loader in order to load virtual
32-
* classfiles.
29+
* - compiles parsed expression in the current REPL state:
30+
* - adds the appropriate imports in scope
31+
* - wraps expressions into a dummy object
32+
* - provides utility to query the type of an expression
33+
* - provides utility to query the documentation of an expression
3334
*/
3435
class ReplCompiler extends Compiler {
35-
override protected def frontendPhases: List[List[Phase]] =
36-
Phases.replace(classOf[FrontEnd], _ => new REPLFrontEnd :: Nil, super.frontendPhases)
3736

38-
def newRun(initCtx: Context, objectIndex: Int) = new Run(this, initCtx) {
39-
override protected[this] def rootContext(implicit ctx: Context) =
40-
addMagicImports(super.rootContext)
41-
42-
private def addMagicImports(initCtx: Context): Context = {
43-
def addImport(path: TermName)(implicit ctx: Context) = {
44-
val importInfo = ImportInfo.rootImport { () =>
45-
ctx.requiredModuleRef(path)
46-
}
47-
ctx.fresh.setNewScope.setImportInfo(importInfo)
37+
override protected def frontendPhases: List[List[Phase]] = List(
38+
List(new REPLFrontEnd),
39+
List(new CollectTopLevelImports),
40+
List(new PostTyper)
41+
)
42+
43+
def newRun(initCtx: Context, state: State): Run = new Run(this, initCtx) {
44+
45+
/** Import previous runs and user defined imports */
46+
override protected[this] def rootContext(implicit ctx: Context): Context = {
47+
def importContext(imp: tpd.Import)(implicit ctx: Context) =
48+
ctx.importContext(imp, imp.symbol)
49+
50+
def importPreviousRun(id: Int)(implicit ctx: Context) = {
51+
// we first import the wrapper object id
52+
val path = nme.EMPTY_PACKAGE ++ "." ++ objectNames(id)
53+
val importInfo = ImportInfo.rootImport(() =>
54+
ctx.requiredModuleRef(path))
55+
val ctx0 = ctx.fresh.setNewScope.setImportInfo(importInfo)
56+
57+
// then its user defined imports
58+
val imports = state.imports.getOrElse(id, Nil)
59+
if (imports.isEmpty) ctx0
60+
else imports.foldLeft(ctx0.fresh.setNewScope)((ctx, imp) =>
61+
importContext(imp)(ctx))
4862
}
4963

50-
(1 to objectIndex)
51-
.foldLeft(initCtx) { (ictx, i) =>
52-
addImport(nme.EMPTY_PACKAGE ++ "." ++ objectNames(i))(ictx)
53-
}
64+
(1 to state.objectIndex).foldLeft(super.rootContext)((ctx, id) =>
65+
importPreviousRun(id)(ctx))
5466
}
5567
}
5668

5769
private[this] val objectNames = mutable.Map.empty[Int, TermName]
5870
private def objectName(state: State) =
59-
objectNames.getOrElseUpdate(state.objectIndex, {
60-
(str.REPL_SESSION_LINE + state.objectIndex).toTermName
61-
})
71+
objectNames.getOrElseUpdate(state.objectIndex,
72+
(str.REPL_SESSION_LINE + state.objectIndex).toTermName)
6273

6374
private case class Definitions(stats: List[untpd.Tree], state: State)
6475

@@ -86,7 +97,7 @@ class ReplCompiler extends Compiler {
8697
}
8798

8899
Definitions(
89-
state.imports ++ defs,
100+
defs,
90101
state.copy(
91102
objectIndex = state.objectIndex + (if (defs.isEmpty) 0 else 1),
92103
valIndex = valIdx
@@ -158,19 +169,18 @@ class ReplCompiler extends Compiler {
158169
def docOf(expr: String)(implicit state: State): Result[String] = {
159170
implicit val ctx: Context = state.context
160171

161-
/**
162-
* Extract the "selected" symbol from `tree`.
172+
/** Extract the "selected" symbol from `tree`.
163173
*
164-
* Because the REPL typechecks an expression, special syntax is needed to get the documentation
165-
* of certain symbols:
174+
* Because the REPL typechecks an expression, special syntax is needed to get the documentation
175+
* of certain symbols:
166176
*
167-
* - To select the documentation of classes, the user needs to pass a call to the class' constructor
168-
* (e.g. `new Foo` to select `class Foo`)
169-
* - When methods are overloaded, the user needs to enter a lambda to specify which functions he wants
170-
* (e.g. `foo(_: Int)` to select `def foo(x: Int)` instead of `def foo(x: String)`
177+
* - To select the documentation of classes, the user needs to pass a call to the class' constructor
178+
* (e.g. `new Foo` to select `class Foo`)
179+
* - When methods are overloaded, the user needs to enter a lambda to specify which functions he wants
180+
* (e.g. `foo(_: Int)` to select `def foo(x: Int)` instead of `def foo(x: String)`
171181
*
172-
* This function returns the right symbol for the received expression, and all the symbols that are
173-
* overridden.
182+
* This function returns the right symbol for the received expression, and all the symbols that are
183+
* overridden.
174184
*/
175185
def extractSymbols(tree: tpd.Tree): Iterator[Symbol] = {
176186
val sym = tree match {
@@ -210,7 +220,7 @@ class ReplCompiler extends Compiler {
210220
import untpd._
211221

212222
val valdef = ValDef("expr".toTermName, TypeTree(), Block(trees, unitLiteral))
213-
val tmpl = Template(emptyConstructor, Nil, EmptyValDef, state.imports :+ valdef)
223+
val tmpl = Template(emptyConstructor, Nil, EmptyValDef, List(valdef))
214224
val wrapper = TypeDef("$wrapper".toTypeName, tmpl)
215225
.withMods(Modifiers(Final))
216226
.withPos(Position(0, expr.length))
@@ -261,9 +271,8 @@ class ReplCompiler extends Compiler {
261271

262272
if (errorsAllowed || !ctx.reporter.hasErrors)
263273
unwrapped(unit.tpdTree, src)
264-
else {
274+
else
265275
ctx.reporter.removeBufferedMessages.errors
266-
}
267276
}
268277
}
269278
}

compiler/src/dotty/tools/repl/ReplDriver.scala

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ import scala.collection.JavaConverters._
4444
*
4545
* @param objectIndex the index of the next wrapper
4646
* @param valIndex the index of next value binding for free expressions
47-
* @param imports the list of user defined imports
47+
* @param imports a map from object index to the list of user defined imports
4848
* @param context the latest compiler context
4949
*/
5050
case class State(objectIndex: Int,
5151
valIndex: Int,
52-
imports: List[untpd.Import],
52+
imports: Map[Int, List[tpd.Import]],
5353
context: Context)
5454

5555
/** Main REPL instance, orchestrating input, compilation and presentation */
@@ -64,14 +64,14 @@ class ReplDriver(settings: Array[String],
6464

6565
/** Create a fresh and initialized context with IDE mode enabled */
6666
private[this] def initialCtx = {
67-
val rootCtx = initCtx.fresh.addMode(Mode.ReadPositions).addMode(Mode.Interactive).addMode(Mode.ReadComments)
67+
val rootCtx = initCtx.fresh.addMode(Mode.ReadPositions | Mode.Interactive | Mode.ReadComments)
6868
val ictx = setup(settings, rootCtx)._2
6969
ictx.base.initialize()(ictx)
7070
ictx
7171
}
7272

7373
/** the initial, empty state of the REPL session */
74-
final def initialState = State(0, 0, Nil, rootCtx)
74+
final def initialState = State(0, 0, Map.empty, rootCtx)
7575

7676
/** Reset state of repl to the initial state
7777
*
@@ -144,7 +144,7 @@ class ReplDriver(settings: Array[String],
144144
Console.withOut(out) { Console.withErr(out) { op } }
145145

146146
private def newRun(state: State) = {
147-
val run = compiler.newRun(rootCtx.fresh.setReporter(newStoreReporter), state.objectIndex)
147+
val run = compiler.newRun(rootCtx.fresh.setReporter(newStoreReporter), state)
148148
state.copy(context = run.runContext)
149149
}
150150

@@ -177,9 +177,6 @@ class ReplDriver(settings: Array[String],
177177
.getOrElse(Nil)
178178
}
179179

180-
private def extractImports(trees: List[untpd.Tree]): List[untpd.Import] =
181-
trees.collect { case imp: untpd.Import => imp }
182-
183180
private def interpret(res: ParseResult)(implicit state: State): State = {
184181
val newState = res match {
185182
case parsed: Parsed if parsed.trees.nonEmpty =>
@@ -209,6 +206,9 @@ class ReplDriver(settings: Array[String],
209206
case _ => nme.NO_NAME
210207
}
211208

209+
def extractTopLevelImports(ctx: Context): List[tpd.Import] =
210+
ctx.phases.collectFirst { case phase: CollectTopLevelImports => phase.imports }.get
211+
212212
implicit val state = newRun(istate)
213213
compiler
214214
.compile(parsed)
@@ -217,8 +217,11 @@ class ReplDriver(settings: Array[String],
217217
{
218218
case (unit: CompilationUnit, newState: State) =>
219219
val newestWrapper = extractNewestWrapper(unit.untpdTree)
220-
val newImports = newState.imports ++ extractImports(parsed.trees)
221-
val newStateWithImports = newState.copy(imports = newImports)
220+
val newImports = extractTopLevelImports(newState.context)
221+
var allImports = newState.imports
222+
if (newImports.nonEmpty)
223+
allImports += (newState.objectIndex -> newImports)
224+
val newStateWithImports = newState.copy(imports = allImports)
222225

223226
val warnings = newState.context.reporter.removeBufferedMessages(newState.context)
224227
displayErrors(warnings)(newState) // display warnings
@@ -315,7 +318,10 @@ class ReplDriver(settings: Array[String],
315318
initialState
316319

317320
case Imports =>
318-
state.imports.foreach(i => out.println(SyntaxHighlighting(i.show(state.context))))
321+
for {
322+
objectIndex <- 1 to state.objectIndex
323+
imp <- state.imports.getOrElse(objectIndex, Nil)
324+
} out.println(imp.show(state.context))
319325
state
320326

321327
case Load(path) =>

compiler/src/dotty/tools/repl/ReplFrontEnd.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import dotc.core.Contexts.Context
1212
* compiler pipeline.
1313
*/
1414
private[repl] class REPLFrontEnd extends FrontEnd {
15-
override def phaseName = "frontend"
1615

1716
override def isRunnable(implicit ctx: Context) = true
1817

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
scala> object A { def f = 1 }
2+
// defined object A
3+
4+
scala> object B { def f = 2 }
5+
// defined object B
6+
7+
scala> import A._
8+
9+
scala> val x0 = f
10+
val x0: Int = 1
11+
12+
scala> import B._
13+
14+
scala> val x1 = f
15+
val x1: Int = 2
16+
17+
scala> def f = 3
18+
def f: Int
19+
20+
scala> val x2 = f
21+
val x2: Int = 3
22+
23+
scala> def f = 4; import A._
24+
def f: Int
25+
26+
scala> val x3 = f
27+
val x3: Int = 1

compiler/test/dotty/tools/repl/TabcompleteTests.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,24 @@ class TabcompleteTests extends ReplTest {
6060
val expected = List("comp1", "comp2", "comp3")
6161
assertEquals(expected, tabComplete("(new Foo).comp").sorted)
6262
}
63+
64+
@Test def completeFromPreviousState2 =
65+
fromInitialState { implicit state =>
66+
val src = "def hello = 1"
67+
run(src)
68+
}
69+
.andThen { implicit state =>
70+
val expected = List("hello")
71+
assertEquals(expected, tabComplete("hel"))
72+
}
73+
74+
@Test def tabCompleteFromPreviousImport =
75+
fromInitialState { implicit state =>
76+
val src = "import java.io.FileDescriptor"
77+
run(src)
78+
}
79+
.andThen { implicit state =>
80+
val expected = List("FileDescriptor")
81+
assertEquals(expected, tabComplete("val foo: FileDesc"))
82+
}
6383
}

0 commit comments

Comments
 (0)