Skip to content

Commit 6f0eafc

Browse files
committed
Fix #4230: Handle import in the REPL correctly
We used to collect user defined imports from the parsed tree and insert them as untyped trees at the top of the REPL wrapper object. This caused members shadowing issues. We now introduce a phase in the REPL compiler that collects imports after type checking and store them as typed tree. We can then create a context with its imports set in the correct order and use it to compile future expressions. This also fixes #4978 by having user defined import in the run context. Auto-completions somehow ignored them when they were part of the untyped tree.
1 parent 13487ca commit 6f0eafc

File tree

8 files changed

+144
-53
lines changed

8 files changed

+144
-53
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ class Compiler {
111111
List(new Flatten, // Lift all inner classes to package scope
112112
new RenameLifted, // Renames lifted classes to local numbering scheme
113113
new TransformWildcards, // Replace wildcards with default values
114-
new MoveStatics, // Move static methods to companion classes
114+
new MoveStatics, // Move static methods from companion to the class itself
115115
new ExpandPrivate, // Widen private definitions accessed from nested classes
116116
new RestoreScopes, // Repair scopes rendered invalid by moving definitions in prior phases of the group
117117
new SelectStatic, // get rid of selects that would be compiled into GetStatic

compiler/src/dotty/tools/dotc/Run.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
147147
compileUnits()(ctx)
148148
}
149149

150-
protected def compileUnits()(implicit ctx: Context) = Stats.maybeMonitored {
150+
private def compileUnits()(implicit ctx: Context) = Stats.maybeMonitored {
151151
if (!ctx.mode.is(Mode.Interactive)) // IDEs might have multi-threaded access, accesses are synchronized
152152
ctx.base.checkSingleThreaded()
153153

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package dotty.tools.repl
2+
3+
import dotty.tools.dotc.ast.Trees._
4+
import dotty.tools.dotc.ast.tpd
5+
import dotty.tools.dotc.core.Contexts.Context
6+
import dotty.tools.dotc.core.Phases.Phase
7+
8+
/** A phase that collects user defined top level imports.
9+
*
10+
* These imports must be collected as typed trees and therefore
11+
* after Typer.
12+
*/
13+
class CollectTopLevelImports extends Phase {
14+
import tpd._
15+
16+
def phaseName = "collectTopLevelImports"
17+
18+
private[this] var myImports: List[Import] = _
19+
def imports = myImports
20+
21+
def run(implicit ctx: Context): Unit = {
22+
def topLevelImports(tree: Tree) = {
23+
val PackageDef(_, _ :: TypeDef(_, rhs: Template) :: Nil) = tree
24+
rhs.body.collect { case tree: Import => tree }
25+
}
26+
27+
val tree = ctx.compilationUnit.tpdTree
28+
myImports = topLevelImports(tree)
29+
}
30+
}

compiler/src/dotty/tools/repl/ReplCompiler.scala

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import dotty.tools.dotc.core.Phases.Phase
1414
import dotty.tools.dotc.core.StdNames._
1515
import dotty.tools.dotc.core.Symbols._
1616
import dotty.tools.dotc.reporting.diagnostic.messages
17+
import dotty.tools.dotc.transform.PostTyper
1718
import dotty.tools.dotc.typer.{FrontEnd, ImportInfo}
1819
import dotty.tools.dotc.util.Positions._
1920
import dotty.tools.dotc.util.SourceFile
@@ -23,42 +24,52 @@ import dotty.tools.repl.results._
2324

2425
import scala.collection.mutable
2526

26-
/** This subclass of `Compiler` replaces the appropriate phases in order to
27-
* facilitate the REPL
27+
/** This subclass of `Compiler` is adapted for use in the REPL.
2828
*
29-
* Specifically it replaces the front end with `REPLFrontEnd`, and adds a
30-
* custom subclass of `GenBCode`. The custom `GenBCode`, `REPLGenBCode`, works
31-
* in conjunction with a specialized class loader in order to load virtual
32-
* classfiles.
29+
* - compiles parsed expression in the current REPL state:
30+
* - adds the appropriate imports in scope
31+
* - wraps expressions into a dummy object
32+
* - provides utility to query the type of an expression
33+
* - provides utility to query the documentation of an expression
3334
*/
3435
class ReplCompiler extends Compiler {
35-
override protected def frontendPhases: List[List[Phase]] =
36-
Phases.replace(classOf[FrontEnd], _ => new REPLFrontEnd :: Nil, super.frontendPhases)
3736

38-
def newRun(initCtx: Context, objectIndex: Int) = new Run(this, initCtx) {
39-
override protected[this] def rootContext(implicit ctx: Context) =
40-
addMagicImports(super.rootContext)
41-
42-
private def addMagicImports(initCtx: Context): Context = {
43-
def addImport(path: TermName)(implicit ctx: Context) = {
44-
val importInfo = ImportInfo.rootImport { () =>
45-
ctx.requiredModuleRef(path)
46-
}
47-
ctx.fresh.setNewScope.setImportInfo(importInfo)
37+
override protected def frontendPhases: List[List[Phase]] = List(
38+
List(new REPLFrontEnd),
39+
List(new CollectTopLevelImports),
40+
List(new PostTyper)
41+
)
42+
43+
def newRun(initCtx: Context, state: State): Run = new Run(this, initCtx) {
44+
45+
/** Import previous runs and user defined imports */
46+
override protected[this] def rootContext(implicit ctx: Context): Context = {
47+
def importContext(imp: tpd.Import)(implicit ctx: Context) =
48+
ctx.importContext(imp, imp.symbol)
49+
50+
def importPreviousRun(id: Int)(implicit ctx: Context) = {
51+
// we first import the wrapper object id
52+
val path = nme.EMPTY_PACKAGE ++ "." ++ objectNames(id)
53+
val importInfo = ImportInfo.rootImport(() =>
54+
ctx.requiredModuleRef(path))
55+
val ctx0 = ctx.fresh.setNewScope.setImportInfo(importInfo)
56+
57+
// then its user defined imports
58+
val imports = state.imports.getOrElse(id, Nil)
59+
if (imports.isEmpty) ctx0
60+
else imports.foldLeft(ctx0.fresh.setNewScope)((ctx, imp) =>
61+
importContext(imp)(ctx))
4862
}
4963

50-
(1 to objectIndex)
51-
.foldLeft(initCtx) { (ictx, i) =>
52-
addImport(nme.EMPTY_PACKAGE ++ "." ++ objectNames(i))(ictx)
53-
}
64+
(1 to state.objectIndex).foldLeft(super.rootContext)((ctx, id) =>
65+
importPreviousRun(id)(ctx))
5466
}
5567
}
5668

5769
private[this] val objectNames = mutable.Map.empty[Int, TermName]
5870
private def objectName(state: State) =
59-
objectNames.getOrElseUpdate(state.objectIndex, {
60-
(str.REPL_SESSION_LINE + state.objectIndex).toTermName
61-
})
71+
objectNames.getOrElseUpdate(state.objectIndex,
72+
(str.REPL_SESSION_LINE + state.objectIndex).toTermName)
6273

6374
private case class Definitions(stats: List[untpd.Tree], state: State)
6475

@@ -86,7 +97,7 @@ class ReplCompiler extends Compiler {
8697
}
8798

8899
Definitions(
89-
state.imports ++ defs,
100+
defs,
90101
state.copy(
91102
objectIndex = state.objectIndex + (if (defs.isEmpty) 0 else 1),
92103
valIndex = valIdx
@@ -158,19 +169,18 @@ class ReplCompiler extends Compiler {
158169
def docOf(expr: String)(implicit state: State): Result[String] = {
159170
implicit val ctx: Context = state.context
160171

161-
/**
162-
* Extract the "selected" symbol from `tree`.
172+
/** Extract the "selected" symbol from `tree`.
163173
*
164-
* Because the REPL typechecks an expression, special syntax is needed to get the documentation
165-
* of certain symbols:
174+
* Because the REPL typechecks an expression, special syntax is needed to get the documentation
175+
* of certain symbols:
166176
*
167-
* - To select the documentation of classes, the user needs to pass a call to the class' constructor
168-
* (e.g. `new Foo` to select `class Foo`)
169-
* - When methods are overloaded, the user needs to enter a lambda to specify which functions he wants
170-
* (e.g. `foo(_: Int)` to select `def foo(x: Int)` instead of `def foo(x: String)`
177+
* - To select the documentation of classes, the user needs to pass a call to the class' constructor
178+
* (e.g. `new Foo` to select `class Foo`)
179+
* - When methods are overloaded, the user needs to enter a lambda to specify which functions he wants
180+
* (e.g. `foo(_: Int)` to select `def foo(x: Int)` instead of `def foo(x: String)`
171181
*
172-
* This function returns the right symbol for the received expression, and all the symbols that are
173-
* overridden.
182+
* This function returns the right symbol for the received expression, and all the symbols that are
183+
* overridden.
174184
*/
175185
def extractSymbols(tree: tpd.Tree): Iterator[Symbol] = {
176186
val sym = tree match {
@@ -210,7 +220,7 @@ class ReplCompiler extends Compiler {
210220
import untpd._
211221

212222
val valdef = ValDef("expr".toTermName, TypeTree(), Block(trees, unitLiteral))
213-
val tmpl = Template(emptyConstructor, Nil, EmptyValDef, state.imports :+ valdef)
223+
val tmpl = Template(emptyConstructor, Nil, EmptyValDef, List(valdef))
214224
val wrapper = TypeDef("$wrapper".toTypeName, tmpl)
215225
.withMods(Modifiers(Final))
216226
.withPos(Position(0, expr.length))
@@ -261,9 +271,8 @@ class ReplCompiler extends Compiler {
261271

262272
if (errorsAllowed || !ctx.reporter.hasErrors)
263273
unwrapped(unit.tpdTree, src)
264-
else {
274+
else
265275
ctx.reporter.removeBufferedMessages.errors
266-
}
267276
}
268277
}
269278
}

compiler/src/dotty/tools/repl/ReplDriver.scala

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ import scala.collection.JavaConverters._
4343
*
4444
* @param objectIndex the index of the next wrapper
4545
* @param valIndex the index of next value binding for free expressions
46-
* @param imports the list of user defined imports
46+
* @param imports a map from object index to the list of user defined imports
4747
* @param context the latest compiler context
4848
*/
4949
case class State(objectIndex: Int,
5050
valIndex: Int,
51-
imports: List[untpd.Import],
51+
imports: Map[Int, List[tpd.Import]],
5252
context: Context)
5353

5454
/** Main REPL instance, orchestrating input, compilation and presentation */
@@ -63,14 +63,14 @@ class ReplDriver(settings: Array[String],
6363

6464
/** Create a fresh and initialized context with IDE mode enabled */
6565
private[this] def initialCtx = {
66-
val rootCtx = initCtx.fresh.addMode(Mode.ReadPositions).addMode(Mode.Interactive).addMode(Mode.ReadComments)
66+
val rootCtx = initCtx.fresh.addMode(Mode.ReadPositions | Mode.Interactive | Mode.ReadComments)
6767
val ictx = setup(settings, rootCtx)._2
6868
ictx.base.initialize()(ictx)
6969
ictx
7070
}
7171

7272
/** the initial, empty state of the REPL session */
73-
protected[this] def initState = State(0, 0, Nil, rootCtx)
73+
protected[this] def initState = State(0, 0, Map.empty, rootCtx)
7474

7575
/** Reset state of repl to the initial state
7676
*
@@ -140,7 +140,7 @@ class ReplDriver(settings: Array[String],
140140
Console.withOut(out) { Console.withErr(out) { op } }
141141

142142
private def newRun(state: State) = {
143-
val run = compiler.newRun(rootCtx.fresh.setReporter(newStoreReporter), state.objectIndex)
143+
val run = compiler.newRun(rootCtx.fresh.setReporter(newStoreReporter), state)
144144
state.copy(context = run.runContext)
145145
}
146146

@@ -173,9 +173,6 @@ class ReplDriver(settings: Array[String],
173173
.getOrElse(Nil)
174174
}
175175

176-
private def extractImports(trees: List[untpd.Tree]): List[untpd.Import] =
177-
trees.collect { case imp: untpd.Import => imp }
178-
179176
private def interpret(res: ParseResult)(implicit state: State): State = {
180177
val newState = res match {
181178
case parsed: Parsed if parsed.trees.nonEmpty =>
@@ -205,6 +202,9 @@ class ReplDriver(settings: Array[String],
205202
case _ => nme.NO_NAME
206203
}
207204

205+
def extractTopLevelImports(ctx: Context): List[tpd.Import] =
206+
ctx.phases.collectFirst { case phase: CollectTopLevelImports => phase.imports }.get
207+
208208
implicit val state = newRun(istate)
209209
compiler
210210
.compile(parsed)
@@ -213,8 +213,11 @@ class ReplDriver(settings: Array[String],
213213
{
214214
case (unit: CompilationUnit, newState: State) =>
215215
val newestWrapper = extractNewestWrapper(unit.untpdTree)
216-
val newImports = newState.imports ++ extractImports(parsed.trees)
217-
val newStateWithImports = newState.copy(imports = newImports)
216+
val newImports = extractTopLevelImports(newState.context)
217+
var allImports = newState.imports
218+
if (newImports.nonEmpty)
219+
allImports += (newState.objectIndex -> newImports)
220+
val newStateWithImports = newState.copy(imports = allImports)
218221

219222
val warnings = newState.context.reporter.removeBufferedMessages(newState.context)
220223
displayErrors(warnings)(newState) // display warnings
@@ -311,7 +314,10 @@ class ReplDriver(settings: Array[String],
311314
initState
312315

313316
case Imports =>
314-
state.imports.foreach(i => out.println(SyntaxHighlighting(i.show(state.context))))
317+
for {
318+
objectIndex <- 1 to state.objectIndex
319+
imp <- state.imports.getOrElse(objectIndex, Nil)
320+
} out.println(imp.show(state.context))
315321
state
316322

317323
case Load(path) =>

compiler/src/dotty/tools/repl/ReplFrontEnd.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import dotc.core.Contexts.Context
1212
* compiler pipeline.
1313
*/
1414
private[repl] class REPLFrontEnd extends FrontEnd {
15-
override def phaseName = "frontend"
1615

1716
override def isRunnable(implicit ctx: Context) = true
1817

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
scala> object A { def f = 1 }
2+
// defined object A
3+
4+
scala> object B { def f = 2 }
5+
// defined object B
6+
7+
scala> import A._
8+
9+
scala> val x0 = f
10+
val x0: Int = 1
11+
12+
scala> import B._
13+
14+
scala> val x1 = f
15+
val x1: Int = 2
16+
17+
scala> def f = 3
18+
def f: Int
19+
20+
scala> val x2 = f
21+
val x2: Int = 3
22+
23+
scala> def f = 4; import A._
24+
def f: Int
25+
26+
scala> val x3 = f
27+
val x3: Int = 1

compiler/test/dotty/tools/repl/TabcompleteTests.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,24 @@ class TabcompleteTests extends ReplTest {
6060
val expected = List("comp1", "comp2", "comp3")
6161
assertEquals(expected, tabComplete("(new Foo).comp").sorted)
6262
}
63+
64+
@Test def completeFromPreviousState2 =
65+
fromInitialState { implicit state =>
66+
val src = "def hello = 1"
67+
run(src)
68+
}
69+
.andThen { implicit state =>
70+
val expected = List("hello")
71+
assertEquals(expected, tabComplete("hel"))
72+
}
73+
74+
@Test def tabCompleteFromPreviousImport =
75+
fromInitialState { implicit state =>
76+
val src = "import java.io.FileDescriptor"
77+
run(src)
78+
}
79+
.andThen { implicit state =>
80+
val expected = List("FileDescriptor")
81+
assertEquals(expected, tabComplete("val foo: FileDesc"))
82+
}
6383
}

0 commit comments

Comments
 (0)