Skip to content

Commit a7a0cd7

Browse files
committed
Improve IDE support for imports
1 parent ac0acf3 commit a7a0cd7

File tree

11 files changed

+447
-88
lines changed

11 files changed

+447
-88
lines changed

compiler/src/dotty/tools/dotc/interactive/Interactive.scala

Lines changed: 100 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import scala.collection._
88
import ast.{NavigateAST, Trees, tpd, untpd}
99
import core._, core.Decorators.{sourcePos => _, _}
1010
import Contexts._, Flags._, Names._, NameOps._, Symbols._, Trees._, Types._
11-
import util.Positions._, util.SourcePosition
11+
import util.Positions._, util.SourceFile, util.SourcePosition
1212
import core.Denotations.SingleDenotation
1313
import NameKinds.SimpleNameKind
1414
import config.Printers.interactiv
@@ -28,6 +28,7 @@ object Interactive {
2828
val references: Int = 4 // include references
2929
val definitions: Int = 8 // include definitions
3030
val linkedClass: Int = 16 // include `symbol.linkedClass`
31+
val imports: Int = 32 // include imports in the results
3132
}
3233

3334
/** Does this tree define a symbol ? */
@@ -59,54 +60,45 @@ object Interactive {
5960
*
6061
* @see sourceSymbol
6162
*/
62-
def enclosingSourceSymbol(path: List[Tree])(implicit ctx: Context): Symbol = {
63-
val sym = path match {
63+
def enclosingSourceSymbols(path: List[Tree], pos: SourcePosition)(implicit ctx: Context): List[Symbol] = {
64+
val syms = path match {
6465
// For a named arg, find the target `DefDef` and jump to the param
6566
case NamedArg(name, _) :: Apply(fn, _) :: _ =>
6667
val funSym = fn.symbol
6768
if (funSym.name == StdNames.nme.copy
6869
&& funSym.is(Synthetic)
6970
&& funSym.owner.is(CaseClass)) {
70-
funSym.owner.info.member(name).symbol
71+
funSym.owner.info.member(name).symbol :: Nil
7172
} else {
7273
val classTree = funSym.topLevelClass.asClass.rootTree
7374
val paramSymbol =
7475
for {
7576
DefDef(_, _, paramss, _, _) <- tpd.defPath(funSym, classTree).lastOption
7677
param <- paramss.flatten.find(_.name == name)
7778
} yield param.symbol
78-
paramSymbol.getOrElse(fn.symbol)
79+
paramSymbol.getOrElse(fn.symbol) :: Nil
7980
}
8081

8182
// For constructor calls, return the `<init>` that was selected
8283
case _ :: (_: New) :: (select: Select) :: _ =>
83-
select.symbol
84+
select.symbol :: Nil
85+
86+
case (_: Thicket) :: (imp: Import) :: _ =>
87+
importedSymbols(imp, _.pos.contains(pos.pos))
88+
89+
case (imp: Import) :: _ =>
90+
importedSymbols(imp, _.pos.contains(pos.pos))
8491

8592
case _ =>
86-
enclosingTree(path).symbol
93+
enclosingTree(path).symbol :: Nil
8794
}
88-
Interactive.sourceSymbol(sym)
89-
}
9095

91-
/**
92-
* The source symbol that is the closest to the path to `pos` in `trees`.
93-
*
94-
* Computes the path from the tree with position `pos` in `trees`, and extract it source
95-
* symbol.
96-
*
97-
* @param trees The trees in which to look for a path to `pos`.
98-
* @param pos That target position of the path.
99-
* @return The source symbol that is the closest to the computed path.
100-
*
101-
* @see sourceSymbol
102-
*/
103-
def enclosingSourceSymbol(trees: List[SourceTree], pos: SourcePosition)(implicit ctx: Context): Symbol = {
104-
enclosingSourceSymbol(pathTo(trees, pos))
96+
syms.map(Interactive.sourceSymbol).filter(_.exists)
10597
}
10698

10799
/** A symbol related to `sym` that is defined in source code.
108100
*
109-
* @see enclosingSourceSymbol
101+
* @see enclosingSourceSymbols
110102
*/
111103
@tailrec def sourceSymbol(sym: Symbol)(implicit ctx: Context): Symbol =
112104
if (!sym.exists)
@@ -304,32 +296,60 @@ object Interactive {
304296
* source code.
305297
*/
306298
def namedTrees(trees: List[SourceTree], include: Include.Set, sym: Symbol)
307-
(implicit ctx: Context): List[SourceTree] =
299+
(implicit ctx: Context): List[SourceNamedTree] =
308300
if (!sym.exists)
309301
Nil
310302
else
311-
namedTrees(trees, (include & Include.references) != 0, matchSymbol(_, sym, include))
303+
namedTrees(trees, include, matchSymbol(_, sym, include))
312304

313305
/** Find named trees with a non-empty position whose name contains `nameSubstring` in `trees`.
314306
*/
315307
def namedTrees(trees: List[SourceTree], nameSubstring: String)
316-
(implicit ctx: Context): List[SourceTree] = {
308+
(implicit ctx: Context): List[SourceNamedTree] = {
317309
val predicate: NameTree => Boolean = _.name.toString.contains(nameSubstring)
318-
namedTrees(trees, includeReferences = false, predicate)
310+
namedTrees(trees, 0, predicate)
319311
}
320312

321313
/** Find named trees with a non-empty position satisfying `treePredicate` in `trees`.
322314
*
323315
* @param includeReferences If true, include references and not just definitions
324316
*/
325-
def namedTrees(trees: List[SourceTree], includeReferences: Boolean, treePredicate: NameTree => Boolean)
326-
(implicit ctx: Context): List[SourceTree] = safely {
327-
val buf = new mutable.ListBuffer[SourceTree]
317+
def namedTrees(trees: List[SourceTree], include: Include.Set, treePredicate: NameTree => Boolean)
318+
(implicit ctx: Context): List[SourceNamedTree] = safely {
319+
val includeReferences = (include & Include.references) != 0
320+
val includeImports = (include & Include.imports) != 0
321+
val buf = new mutable.ListBuffer[SourceNamedTree]
328322

329-
trees foreach { case SourceTree(topTree, source) =>
323+
def traverser(source: SourceFile) = {
330324
new untpd.TreeTraverser {
325+
private def handleImport(imported: List[Symbol],
326+
uexpr: untpd.Tree,
327+
id: untpd.Ident,
328+
rename: Option[untpd.Ident]): Unit = {
329+
val expr = uexpr.asInstanceOf[tpd.Tree]
330+
imported match {
331+
case Nil =>
332+
traverse(expr)
333+
case syms =>
334+
syms.foreach { sym =>
335+
val tree = tpd.Select(expr, sym.name).withPos(id.pos)
336+
val renameTree = rename.map { r =>
337+
val name = if (sym.name.isTypeName) r.name.toTypeName else r.name
338+
RenameTree(name, tpd.Select(expr, sym.name)).withPos(r.pos)
339+
}
340+
renameTree.foreach(traverse)
341+
traverse(tree)
342+
}
343+
}
344+
}
331345
override def traverse(tree: untpd.Tree)(implicit ctx: Context) = {
332346
tree match {
347+
case imp @ Import(uexpr, (id: untpd.Ident) :: Nil) if includeImports =>
348+
val imported = importedSymbols(imp.asInstanceOf[tpd.Import])
349+
handleImport(imported, uexpr, id, None)
350+
case imp @ Import(uexpr, Thicket((id: untpd.Ident) :: (rename: untpd.Ident) :: Nil) :: Nil) if includeImports =>
351+
val imported = importedSymbols(imp.asInstanceOf[tpd.Import])
352+
handleImport(imported, uexpr, id, Some(rename))
333353
case utree: untpd.NameTree if tree.hasType =>
334354
val tree = utree.asInstanceOf[tpd.NameTree]
335355
if (tree.symbol.exists
@@ -338,17 +358,19 @@ object Interactive {
338358
&& !tree.pos.isZeroExtent
339359
&& (includeReferences || isDefinition(tree))
340360
&& treePredicate(tree))
341-
buf += SourceTree(tree, source)
361+
buf += SourceNamedTree(tree, source)
342362
traverseChildren(tree)
343363
case tree: untpd.Inlined =>
344364
traverse(tree.call)
345365
case _ =>
346366
traverseChildren(tree)
347367
}
348368
}
349-
}.traverse(topTree)
369+
}
350370
}
351371

372+
trees.foreach(t => traverser(t.source).traverse(t.tree))
373+
352374
buf.toList
353375
}
354376

@@ -361,9 +383,8 @@ object Interactive {
361383
*/
362384
def findTreesMatching(trees: List[SourceTree],
363385
includes: Include.Set,
364-
symbol: Symbol)(implicit ctx: Context): List[SourceTree] = {
386+
symbol: Symbol)(implicit ctx: Context): List[SourceNamedTree] = {
365387
val linkedSym = symbol.linkedClass
366-
val includeReferences = (includes & Include.references) != 0
367388
val includeDeclaration = (includes & Include.definitions) != 0
368389
val includeLinkedClass = (includes & Include.linkedClass) != 0
369390
val predicate: NameTree => Boolean = tree =>
@@ -377,7 +398,7 @@ object Interactive {
377398
)
378399
)
379400
)
380-
namedTrees(trees, includeReferences, predicate)
401+
namedTrees(trees, includes, predicate)
381402
}
382403

383404
/** The reverse path to the node that closest encloses position `pos`,
@@ -465,10 +486,8 @@ object Interactive {
465486
* @param driver The driver responsible for `path`.
466487
* @return The definitions for the symbol at the end of `path`.
467488
*/
468-
def findDefinitions(path: List[Tree], driver: InteractiveDriver)(implicit ctx: Context): List[SourceTree] = {
469-
val sym = enclosingSourceSymbol(path)
470-
if (sym == NoSymbol) Nil
471-
else {
489+
def findDefinitions(path: List[Tree], pos: SourcePosition, driver: InteractiveDriver)(implicit ctx: Context): List[SourceNamedTree] = {
490+
enclosingSourceSymbols(path, pos).flatMap { sym =>
472491
val enclTree = enclosingTree(path)
473492

474493
val (trees, include) =
@@ -492,4 +511,44 @@ object Interactive {
492511
}
493512
}
494513

514+
/**
515+
* All the symbols that are imported by import statement `imp`, if it matches
516+
* the predicate `selectorPredicate`.
517+
*
518+
* @param imp The import statement to analyze
519+
* @param selectorPredicate A test to find the selector to use.
520+
* @return The symbols imported.
521+
*/
522+
private def importedSymbols(imp: tpd.Import,
523+
selectorPredicate: untpd.Tree => Boolean = util.common.alwaysTrue)
524+
(implicit ctx: Context): List[Symbol] = {
525+
def lookup0(name: Name): Symbol = imp.expr.tpe.member(name).symbol
526+
def lookup(name: Name): List[Symbol] = {
527+
lookup0(name.toTermName) ::
528+
lookup0(name.toTypeName) ::
529+
lookup0(name.moduleClassName) ::
530+
lookup0(name.sourceModuleName) :: Nil
531+
}
532+
533+
val symbols = imp.selectors.find(selectorPredicate) match {
534+
case Some(id: untpd.Ident) =>
535+
lookup(id.name)
536+
case Some(Thicket((id: untpd.Ident) :: (_: untpd.Ident) :: Nil)) =>
537+
lookup(id.name)
538+
case _ => Nil
539+
}
540+
541+
symbols.map(sourceSymbol).filter(_.exists).distinct
542+
}
543+
544+
/**
545+
* Used to represent a renaming import `{foo => bar}`.
546+
* We need this because the name of the tree must be the new name, but the
547+
* denotation must be that of the importee.
548+
*/
549+
private case class RenameTree(name: Name, underlying: Tree) extends NameTree {
550+
override def denot(implicit ctx: Context) = underlying.denot
551+
myTpe = NoType
552+
}
553+
495554
}

compiler/src/dotty/tools/dotc/interactive/InteractiveDriver.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,19 @@ class InteractiveDriver(val settings: List[String]) extends Driver {
5959
val fromSource = openedTrees.values.flatten.toList
6060
val fromClassPath = (dirClassPathClasses ++ zipClassPathClasses).flatMap { cls =>
6161
val className = cls.toTypeName
62-
List(tree(className, id), tree(className.moduleClassName, id)).flatten
62+
trees(className, id) ::: trees(className.moduleClassName, id)
6363
}
6464
(fromSource ++ fromClassPath).distinct
6565
}
6666

67-
private def tree(className: TypeName, id: String)(implicit ctx: Context): Option[SourceTree] = {
67+
private def trees(className: TypeName, id: String)(implicit ctx: Context): List[SourceTree] = {
6868
val clsd = ctx.base.staticRef(className)
6969
clsd match {
7070
case clsd: ClassDenotation =>
7171
clsd.ensureCompleted()
7272
SourceTree.fromSymbol(clsd.symbol.asClass, id)
7373
case _ =>
74-
None
74+
Nil
7575
}
7676
}
7777

@@ -170,14 +170,16 @@ class InteractiveDriver(val settings: List[String]) extends Driver {
170170
names.toList
171171
}
172172

173-
private def topLevelClassTrees(topTree: Tree, source: SourceFile): List[SourceTree] = {
173+
private def topLevelTrees(topTree: Tree, source: SourceFile): List[SourceTree] = {
174174
val trees = new mutable.ListBuffer[SourceTree]
175175

176176
def addTrees(tree: Tree): Unit = tree match {
177177
case PackageDef(_, stats) =>
178178
stats.foreach(addTrees)
179+
case imp: Import =>
180+
trees += SourceImportTree(imp, source)
179181
case tree: TypeDef =>
180-
trees += SourceTree(tree, source)
182+
trees += SourceNamedTree(tree, source)
181183
case _ =>
182184
}
183185
addTrees(topTree)
@@ -244,7 +246,7 @@ class InteractiveDriver(val settings: List[String]) extends Driver {
244246
val unit = ctx.run.units.head
245247
val t = unit.tpdTree
246248
cleanup(t)
247-
myOpenedTrees(uri) = topLevelClassTrees(t, source)
249+
myOpenedTrees(uri) = topLevelTrees(t, source)
248250
myCompilationUnits(uri) = unit
249251

250252
reporter.removeBufferedMessages

compiler/src/dotty/tools/dotc/interactive/SourceTree.scala

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,24 @@ import core._, core.Decorators.{sourcePos => _}
99
import Contexts._, NameOps._, Symbols._, StdNames._
1010
import util._, util.Positions._
1111

12-
/** A typechecked named `tree` coming from `source` */
13-
case class SourceTree(tree: tpd.NameTree, source: SourceFile) {
12+
/** A `tree` coming from `source` */
13+
sealed trait SourceTree {
14+
15+
/** The underlying tree. */
16+
def tree: tpd.Tree
17+
18+
/** The source from which `tree` comes. */
19+
def source: SourceFile
20+
1421
/** The position of `tree` */
15-
def pos(implicit ctx: Context): SourcePosition = source.atPos(tree.pos)
22+
final def pos(implicit ctx: Context): SourcePosition = source.atPos(tree.pos)
23+
}
24+
25+
/** An import coming from `source` */
26+
case class SourceImportTree(tree: tpd.Import, source: SourceFile) extends SourceTree
27+
28+
/** A typechecked `tree` coming from `source` */
29+
case class SourceNamedTree(tree: tpd.NameTree, source: SourceFile) extends SourceTree {
1630

1731
/** The position of the name in `tree` */
1832
def namePos(implicit ctx: Context): SourcePosition = {
@@ -43,21 +57,33 @@ case class SourceTree(tree: tpd.NameTree, source: SourceFile) {
4357
}
4458

4559
object SourceTree {
46-
def fromSymbol(sym: ClassSymbol, id: String = "")(implicit ctx: Context): Option[SourceTree] = {
60+
def fromSymbol(sym: ClassSymbol, id: String = "")(implicit ctx: Context): List[SourceTree] = {
4761
if (sym == defn.SourceFileAnnot || // FIXME: No SourceFile annotation on SourceFile itself
4862
sym.sourceFile == null) // FIXME: We cannot deal with external projects yet
49-
None
63+
Nil
5064
else {
5165
import ast.Trees._
52-
def sourceTreeOfClass(tree: tpd.Tree): Option[SourceTree] = tree match {
66+
def sourceTreeOfClass(tree: tpd.Tree): Option[SourceNamedTree] = tree match {
5367
case PackageDef(_, stats) =>
5468
stats.flatMap(sourceTreeOfClass).headOption
5569
case tree: tpd.TypeDef if tree.symbol == sym =>
5670
val sourceFile = new SourceFile(sym.sourceFile, Codec.UTF8)
57-
Some(SourceTree(tree, sourceFile))
58-
case _ => None
71+
Some(SourceNamedTree(tree, sourceFile))
72+
case _ =>
73+
None
74+
}
75+
76+
def sourceImports(tree: tpd.Tree, sourceFile: SourceFile): List[SourceImportTree] = tree match {
77+
case PackageDef(_, stats) => stats.flatMap(sourceImports(_, sourceFile))
78+
case imp: tpd.Import => SourceImportTree(imp, sourceFile) :: Nil
79+
case _ => Nil
80+
}
81+
82+
val tree = sym.rootTreeContaining(id)
83+
sourceTreeOfClass(tree) match {
84+
case Some(namedTree) => namedTree :: sourceImports(tree, namedTree.source)
85+
case None => Nil
5986
}
60-
sourceTreeOfClass(sym.rootTreeContaining(id))
6187
}
6288
}
6389
}

0 commit comments

Comments
 (0)