diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index e30180edd145..eb0ad1dc3384 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1147,5 +1147,91 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { case _ => EmptyTree } + + /** + * The symbols that are imported with `expr.name` + * + * @param expr The base of the import statement + * @param name The name that is being imported. + * @return All the symbols that would be imported with `expr.name`. + */ + def importedSymbols(expr: Tree, name: Name)(implicit ctx: Context): List[Symbol] = { + def lookup(name: Name): Symbol = expr.tpe.member(name).symbol + val symbols = + List(lookup(name.toTermName), + lookup(name.toTypeName), + lookup(name.moduleClassName), + lookup(name.sourceModuleName)) + + symbols.map(_.sourceSymbol).filter(_.exists).distinct + } + + /** + * All the symbols that are imported by the first selector of `imp` that matches + * `selectorPredicate`. + * + * @param imp The import statement to analyze + * @param selectorPredicate A test to find the selector to use. + * @return The symbols imported. + */ + def importedSymbols(imp: Import, + selectorPredicate: untpd.Tree => Boolean = util.common.alwaysTrue) + (implicit ctx: Context): List[Symbol] = { + imp.selectors.find(selectorPredicate) match { + case Some(id: untpd.Ident) => + importedSymbols(imp.expr, id.name) + case Some(Thicket((id: untpd.Ident) :: (_: untpd.Ident) :: Nil)) => + importedSymbols(imp.expr, id.name) + case _ => + Nil + } + } + + /** + * The list of select trees that resolve to the same symbols as the ones that are imported + * by `imp`. + */ + def importSelections(imp: Import)(implicit ctx: Context): List[Select] = { + def imported(sym: Symbol, id: untpd.Ident, rename: Option[untpd.Ident]): List[Select] = { + // Give a zero-extent position to the qualifier to prevent it from being included several + // times in results in the language server. + val noPosExpr = focusPositions(imp.expr) + val selectTree = Select(noPosExpr, sym.name).withPos(id.pos) + rename match { + case None => + selectTree :: Nil + case Some(rename) => + // Get the type of the symbol that is actually selected, and construct a select + // node with the new name and the type of the real symbol. + val name = if (sym.name.isTypeName) rename.name.toTypeName else rename.name + val actual = Select(noPosExpr, sym.name) + val renameTree = Select(noPosExpr, name).withPos(rename.pos).withType(actual.tpe) + selectTree :: renameTree :: Nil + } + } + + imp.selectors.flatMap { + case Ident(nme.WILDCARD) => + Nil + case id: untpd.Ident => + importedSymbols(imp.expr, id.name).flatMap { sym => + imported(sym, id, None) + } + case Thicket((id: untpd.Ident) :: (newName: untpd.Ident) :: Nil) => + importedSymbols(imp.expr, id.name).flatMap { sym => + imported(sym, id, Some(newName)) + } + } + } + + /** Replaces all positions in `tree` with zero-extent positions */ + private def focusPositions(tree: Tree)(implicit ctx: Context): Tree = { + val transformer = new tpd.TreeMap { + override def transform(tree: Tree)(implicit ctx: Context): Tree = { + super.transform(tree).withPos(tree.pos.focus) + } + } + transformer.transform(tree) + } } diff --git a/compiler/src/dotty/tools/dotc/core/Symbols.scala b/compiler/src/dotty/tools/dotc/core/Symbols.scala index dda067d926a9..d6b6ed1efd1c 100644 --- a/compiler/src/dotty/tools/dotc/core/Symbols.scala +++ b/compiler/src/dotty/tools/dotc/core/Symbols.scala @@ -624,6 +624,26 @@ object Symbols { } } + /** A symbol related to `sym` that is defined in source code. + * + * @see enclosingSourceSymbols + */ + @annotation.tailrec final def sourceSymbol(implicit ctx: Context): Symbol = + if (!denot.exists) + this + else if (denot.is(ModuleVal)) + this.moduleClass.sourceSymbol // The module val always has a zero-extent position + else if (denot.is(Synthetic)) { + val linked = denot.linkedClass + if (linked.exists && !linked.is(Synthetic)) + linked + else + denot.owner.sourceSymbol + } + else if (denot.isPrimaryConstructor) + denot.owner.sourceSymbol + else this + /** The position of this symbol, or NoPosition if the symbol was not loaded * from source or from TASTY. This is always a zero-extent position. * diff --git a/compiler/src/dotty/tools/dotc/interactive/Interactive.scala b/compiler/src/dotty/tools/dotc/interactive/Interactive.scala index 8c0e69afc5a6..b33d6c0407e1 100644 --- a/compiler/src/dotty/tools/dotc/interactive/Interactive.scala +++ b/compiler/src/dotty/tools/dotc/interactive/Interactive.scala @@ -8,7 +8,7 @@ import scala.collection._ import ast.{NavigateAST, Trees, tpd, untpd} import core._, core.Decorators.{sourcePos => _, _} import Contexts._, Flags._, Names._, NameOps._, Symbols._, Trees._, Types._ -import util.Positions._, util.SourcePosition +import util.Positions._, util.SourceFile, util.SourcePosition import core.Denotations.SingleDenotation import NameKinds.SimpleNameKind import config.Printers.interactiv @@ -21,13 +21,46 @@ import StdNames.nme object Interactive { import ast.tpd._ - object Include { // should be an enum, really. - type Set = Int - val overridden: Int = 1 // include trees whose symbol is overridden by `sym` - val overriding: Int = 2 // include trees whose symbol overrides `sym` (but for performance only in same source file) - val references: Int = 4 // include references - val definitions: Int = 8 // include definitions - val linkedClass: Int = 16 // include `symbol.linkedClass` + object Include { + case class Set private (val bits: Int) extends AnyVal { + def | (that: Set): Set = Set(bits | that.bits) + def except(that: Set): Set = Set(bits & ~that.bits) + + def isEmpty: Boolean = bits == 0 + def isOverridden: Boolean = (bits & overridden.bits) != 0 + def isOverriding: Boolean = (bits & overriding.bits) != 0 + def isReferences: Boolean = (bits & references.bits) != 0 + def isDefinitions: Boolean = (bits & definitions.bits) != 0 + def isLinkedClass: Boolean = (bits & linkedClass.bits) != 0 + def isImports: Boolean = (bits & imports.bits) != 0 + } + + /** The empty set */ + val empty: Set = Set(0) + + /** Include trees whose symbol is overridden by `sym` */ + val overridden: Set = Set(1 << 0) + + /** + * Include trees whose symbol overrides `sym` (but for performance only in same source + * file) + */ + val overriding: Set = Set(1 << 1) + + /** Include references */ + val references: Set = Set(1 << 2) + + /** Include definitions */ + val definitions: Set = Set(1 << 3) + + /** Include `sym.linkedClass */ + val linkedClass: Set = Set(1 << 4) + + /** Include imports in the results */ + val imports: Set = Set(1 << 5) + + /** All the flags */ + val all: Set = Set(~0) } /** Does this tree define a symbol ? */ @@ -52,76 +85,52 @@ object Interactive { path.dropWhile(!_.symbol.exists).headOption.getOrElse(tpd.EmptyTree) /** - * The source symbol that is the closest to `path`. + * The source symbols that are the closest to `path`. + * + * If this path ends in an import, then this returns all the symbols that are imported by this + * import statement. * - * @param path The path to the tree whose symbol to extract. - * @return The source symbol that is the closest to `path`. + * @param path The path to the tree whose symbols to extract. + * @return The source symbols that are the closest to `path`. * * @see sourceSymbol */ - def enclosingSourceSymbol(path: List[Tree])(implicit ctx: Context): Symbol = { - val sym = path match { + def enclosingSourceSymbols(path: List[Tree], pos: SourcePosition)(implicit ctx: Context): List[Symbol] = { + val syms = path match { // For a named arg, find the target `DefDef` and jump to the param case NamedArg(name, _) :: Apply(fn, _) :: _ => val funSym = fn.symbol if (funSym.name == StdNames.nme.copy && funSym.is(Synthetic) && funSym.owner.is(CaseClass)) { - funSym.owner.info.member(name).symbol + List(funSym.owner.info.member(name).symbol) } else { val classTree = funSym.topLevelClass.asClass.rootTree - tpd.defPath(funSym, classTree).lastOption.flatMap { - case DefDef(_, _, paramss, _, _) => - paramss.flatten.find(_.name == name).map(_.symbol) - }.getOrElse(fn.symbol) + val paramSymbol = + for { + DefDef(_, _, paramss, _, _) <- tpd.defPath(funSym, classTree).lastOption + param <- paramss.flatten.find(_.name == name) + } yield param.symbol + List(paramSymbol.getOrElse(fn.symbol)) } // For constructor calls, return the `` that was selected case _ :: (_: New) :: (select: Select) :: _ => - select.symbol + List(select.symbol) + + case (_: Thicket) :: (imp: Import) :: _ => + importedSymbols(imp, _.pos.contains(pos.pos)) + + case (imp: Import) :: _ => + importedSymbols(imp, _.pos.contains(pos.pos)) case _ => - enclosingTree(path).symbol + List(enclosingTree(path).symbol) } - Interactive.sourceSymbol(sym) - } - /** - * The source symbol that is the closest to the path to `pos` in `trees`. - * - * Computes the path from the tree with position `pos` in `trees`, and extract it source - * symbol. - * - * @param trees The trees in which to look for a path to `pos`. - * @param pos That target position of the path. - * @return The source symbol that is the closest to the computed path. - * - * @see sourceSymbol - */ - def enclosingSourceSymbol(trees: List[SourceTree], pos: SourcePosition)(implicit ctx: Context): Symbol = { - enclosingSourceSymbol(pathTo(trees, pos)) + syms.map(_.sourceSymbol).filter(_.exists) } - /** A symbol related to `sym` that is defined in source code. - * - * @see enclosingSourceSymbol - */ - @tailrec def sourceSymbol(sym: Symbol)(implicit ctx: Context): Symbol = - if (!sym.exists) - sym - else if (sym.is(ModuleVal)) - sourceSymbol(sym.moduleClass) // The module val always has a zero-extent position - else if (sym.is(Synthetic)) { - val linked = sym.linkedClass - if (linked.exists && !linked.is(Synthetic)) - linked - else - sourceSymbol(sym.owner) - } - else if (sym.isPrimaryConstructor) - sourceSymbol(sym.owner) - else sym - /** Check if `tree` matches `sym`. * This is the case if the symbol defined by `tree` equals `sym`, * or the source symbol of tree equals sym, @@ -134,10 +143,10 @@ object Interactive { sym1.owner.derivesFrom(sym2.owner) && sym1.overriddenSymbol(sym2.owner.asClass) == sym2 ( sym == tree.symbol - || sym.exists && sym == sourceSymbol(tree.symbol) - || include != 0 && sym.name == tree.symbol.name && sym.maybeOwner != tree.symbol.maybeOwner - && ( (include & Include.overridden) != 0 && overrides(sym, tree.symbol) - || (include & Include.overriding) != 0 && overrides(tree.symbol, sym) + || sym.exists && sym == tree.symbol.sourceSymbol + || !include.isEmpty && sym.name == tree.symbol.name && sym.maybeOwner != tree.symbol.maybeOwner + && ( include.isOverridden && overrides(sym, tree.symbol) + || include.isOverriding && overrides(tree.symbol, sym) ) ) } @@ -306,35 +315,40 @@ object Interactive { if (!sym.exists) Nil else - namedTrees(trees, (include & Include.references) != 0, matchSymbol(_, sym, include)) + namedTrees(trees, include, matchSymbol(_, sym, include)) /** Find named trees with a non-empty position whose name contains `nameSubstring` in `trees`. */ def namedTrees(trees: List[SourceTree], nameSubstring: String) (implicit ctx: Context): List[SourceTree] = { val predicate: NameTree => Boolean = _.name.toString.contains(nameSubstring) - namedTrees(trees, includeReferences = false, predicate) + namedTrees(trees, Include.empty, predicate) } /** Find named trees with a non-empty position satisfying `treePredicate` in `trees`. * * @param includeReferences If true, include references and not just definitions */ - def namedTrees(trees: List[SourceTree], includeReferences: Boolean, treePredicate: NameTree => Boolean) + def namedTrees(trees: List[SourceTree], include: Include.Set, treePredicate: NameTree => Boolean) (implicit ctx: Context): List[SourceTree] = safely { val buf = new mutable.ListBuffer[SourceTree] - trees foreach { case SourceTree(topTree, source) => + def traverser(source: SourceFile) = { new untpd.TreeTraverser { override def traverse(tree: untpd.Tree)(implicit ctx: Context) = { tree match { + case imp: untpd.Import if include.isImports && tree.hasType => + val tree = imp.asInstanceOf[tpd.Import] + val selections = tpd.importSelections(tree) + traverse(imp.expr) + selections.foreach(traverse) case utree: untpd.NameTree if tree.hasType => val tree = utree.asInstanceOf[tpd.NameTree] if (tree.symbol.exists && !tree.symbol.is(Synthetic) && tree.pos.exists && !tree.pos.isZeroExtent - && (includeReferences || isDefinition(tree)) + && (include.isReferences || isDefinition(tree)) && treePredicate(tree)) buf += SourceTree(tree, source) traverseChildren(tree) @@ -344,38 +358,40 @@ object Interactive { traverseChildren(tree) } } - }.traverse(topTree) + } } + trees.foreach(t => traverser(t.source).traverse(t.tree)) + buf.toList } /** * Find trees that match `symbol` in `trees`. * - * @param trees The trees to inspect. - * @param includes Whether to include references, definitions, etc. - * @param symbol The symbol for which we want to find references. + * @param trees The trees to inspect. + * @param includes Whether to include references, definitions, etc. + * @param symbol The symbol for which we want to find references. + * @param predicate An additional predicate that the trees must match. */ def findTreesMatching(trees: List[SourceTree], includes: Include.Set, - symbol: Symbol)(implicit ctx: Context): List[SourceTree] = { + symbol: Symbol, + predicate: NameTree => Boolean = util.common.alwaysTrue + )(implicit ctx: Context): List[SourceTree] = { val linkedSym = symbol.linkedClass - val includeReferences = (includes & Include.references) != 0 - val includeDeclaration = (includes & Include.definitions) != 0 - val includeLinkedClass = (includes & Include.linkedClass) != 0 - val predicate: NameTree => Boolean = tree => + val fullPredicate: NameTree => Boolean = tree => ( !tree.symbol.isPrimaryConstructor - && (includeDeclaration || !Interactive.isDefinition(tree)) + && (includes.isDefinitions || !Interactive.isDefinition(tree)) && ( Interactive.matchSymbol(tree, symbol, includes) - || ( includeDeclaration - && includeLinkedClass + || ( includes.isLinkedClass && linkedSym.exists && Interactive.matchSymbol(tree, linkedSym, includes) ) ) + && predicate(tree) ) - namedTrees(trees, includeReferences, predicate) + namedTrees(trees, includes, fullPredicate) } /** The reverse path to the node that closest encloses position `pos`, @@ -463,10 +479,8 @@ object Interactive { * @param driver The driver responsible for `path`. * @return The definitions for the symbol at the end of `path`. */ - def findDefinitions(path: List[Tree], driver: InteractiveDriver)(implicit ctx: Context): List[SourceTree] = { - val sym = enclosingSourceSymbol(path) - if (sym == NoSymbol) Nil - else { + def findDefinitions(path: List[Tree], pos: SourcePosition, driver: InteractiveDriver)(implicit ctx: Context): List[SourceTree] = { + enclosingSourceSymbols(path, pos).flatMap { sym => val enclTree = enclosingTree(path) val (trees, include) = @@ -483,7 +497,7 @@ object Interactive { } (trees, Include.definitions | Include.overriding) case _ => - (Nil, 0) + (Nil, Include.empty) } findTreesMatching(trees, include, sym) @@ -541,4 +555,21 @@ object Interactive { } } + /** + * Is this tree using a renaming introduced by an import statement or an alias for `this`? + * + * @param tree The tree to inspect + * @return True, if this tree's name is different than its symbol's name, indicating that + * it uses a renaming introduced by an import statement or an alias for `this`. + */ + def isRenamed(tree: NameTree)(implicit ctx: Context): Boolean = { + val symbol = tree.symbol + symbol.exists && !sameName(tree.name, symbol.name) + } + + /** Are the two names the same? */ + def sameName(n0: Name, n1: Name): Boolean = { + n0.stripModuleClassSuffix.toTermName eq n1.stripModuleClassSuffix.toTermName + } + } diff --git a/compiler/src/dotty/tools/dotc/interactive/InteractiveDriver.scala b/compiler/src/dotty/tools/dotc/interactive/InteractiveDriver.scala index 0e4c38fddf30..3bb3a53ef632 100644 --- a/compiler/src/dotty/tools/dotc/interactive/InteractiveDriver.scala +++ b/compiler/src/dotty/tools/dotc/interactive/InteractiveDriver.scala @@ -162,7 +162,7 @@ class InteractiveDriver(val settings: List[String]) extends Driver { val unit = ctx.run.units.head val t = unit.tpdTree cleanup(t) - myOpenedTrees(uri) = topLevelClassTrees(t, source) + myOpenedTrees(uri) = topLevelTrees(t, source) myCompilationUnits(uri) = unit reporter.removeBufferedMessages @@ -187,17 +187,17 @@ class InteractiveDriver(val settings: List[String]) extends Driver { * @see SourceTree.fromSymbol */ private def treesFromClassName(className: TypeName, id: String)(implicit ctx: Context): List[SourceTree] = { - def tree(className: TypeName, id: String): Option[SourceTree] = { + def trees(className: TypeName, id: String): List[SourceTree] = { val clsd = ctx.base.staticRef(className) clsd match { case clsd: ClassDenotation => clsd.ensureCompleted() SourceTree.fromSymbol(clsd.symbol.asClass, id) case _ => - None + Nil } } - List(tree(className, id), tree(className.moduleClassName, id)).flatten + trees(className, id) ::: trees(className.moduleClassName, id) } // FIXME: classfiles in directories may change at any point, so we retraverse @@ -246,12 +246,14 @@ class InteractiveDriver(val settings: List[String]) extends Driver { } } - private def topLevelClassTrees(topTree: Tree, source: SourceFile): List[SourceTree] = { + private def topLevelTrees(topTree: Tree, source: SourceFile): List[SourceTree] = { val trees = new mutable.ListBuffer[SourceTree] def addTrees(tree: Tree): Unit = tree match { case PackageDef(_, stats) => stats.foreach(addTrees) + case imp: Import => + trees += SourceTree(imp, source) case tree: TypeDef => trees += SourceTree(tree, source) case _ => diff --git a/compiler/src/dotty/tools/dotc/interactive/SourceTree.scala b/compiler/src/dotty/tools/dotc/interactive/SourceTree.scala index 6d197d38077a..2ceda42acf5f 100644 --- a/compiler/src/dotty/tools/dotc/interactive/SourceTree.scala +++ b/compiler/src/dotty/tools/dotc/interactive/SourceTree.scala @@ -9,44 +9,52 @@ import core._, core.Decorators.{sourcePos => _} import Contexts._, NameOps._, Symbols._, StdNames._ import util._, util.Positions._ -/** A typechecked named `tree` coming from `source` */ -case class SourceTree(tree: tpd.NameTree, source: SourceFile) { +/** + * A `tree` coming from `source` + * + * `tree` can be either an `Import` or a `NameTree`. + */ +case class SourceTree(tree: tpd.Tree /** really: tpd.Import | tpd.NameTree */, source: SourceFile) { + /** The position of `tree` */ - def pos(implicit ctx: Context): SourcePosition = source.atPos(tree.pos) + final def pos(implicit ctx: Context): SourcePosition = source.atPos(tree.pos) /** The position of the name in `tree` */ - def namePos(implicit ctx: Context): SourcePosition = { - // FIXME: Merge with NameTree#namePos ? - val treePos = tree.pos - if (treePos.isZeroExtent || tree.name.toTermName == nme.ERROR) - NoSourcePosition - else { - // Constructors are named `` in the trees, but `this` in the source. - val nameLength = tree.name match { - case nme.CONSTRUCTOR => nme.this_.toString.length - case other => other.stripModuleClassSuffix.show.toString.length + def namePos(implicit ctx: Context): SourcePosition = tree match { + case tree: tpd.NameTree => + // FIXME: Merge with NameTree#namePos ? + val treePos = tree.pos + if (treePos.isZeroExtent || tree.name.toTermName == nme.ERROR) + NoSourcePosition + else { + // Constructors are named `` in the trees, but `this` in the source. + val nameLength = tree.name match { + case nme.CONSTRUCTOR => nme.this_.toString.length + case other => other.stripModuleClassSuffix.show.toString.length + } + val position = { + // FIXME: This is incorrect in some cases, like with backquoted identifiers, + // see https://github.com/lampepfl/dotty/pull/1634#issuecomment-257079436 + val (start, end) = + if (!treePos.isSynthetic) + (treePos.point, treePos.point + nameLength) + else + // If we don't have a point, we need to find it + (treePos.end - nameLength, treePos.end) + Position(start, end, start) + } + source.atPos(position) } - val position = { - // FIXME: This is incorrect in some cases, like with backquoted identifiers, - // see https://github.com/lampepfl/dotty/pull/1634#issuecomment-257079436 - val (start, end) = - if (!treePos.isSynthetic) - (treePos.point, treePos.point + nameLength) - else - // If we don't have a point, we need to find it - (treePos.end - nameLength, treePos.end) - Position(start, end, start) - } - source.atPos(position) - } + case _ => + NoSourcePosition } } object SourceTree { - def fromSymbol(sym: ClassSymbol, id: String = "")(implicit ctx: Context): Option[SourceTree] = { + def fromSymbol(sym: ClassSymbol, id: String = "")(implicit ctx: Context): List[SourceTree] = { if (sym == defn.SourceFileAnnot || // FIXME: No SourceFile annotation on SourceFile itself sym.sourceFile == null) // FIXME: We cannot deal with external projects yet - None + Nil else { import ast.Trees._ def sourceTreeOfClass(tree: tpd.Tree): Option[SourceTree] = tree match { @@ -55,9 +63,21 @@ object SourceTree { case tree: tpd.TypeDef if tree.symbol == sym => val sourceFile = new SourceFile(sym.sourceFile, Codec.UTF8) Some(SourceTree(tree, sourceFile)) - case _ => None + case _ => + None + } + + def sourceImports(tree: tpd.Tree, sourceFile: SourceFile): List[SourceTree] = tree match { + case PackageDef(_, stats) => stats.flatMap(sourceImports(_, sourceFile)) + case imp: tpd.Import => SourceTree(imp, sourceFile) :: Nil + case _ => Nil + } + + val tree = sym.rootTreeContaining(id) + sourceTreeOfClass(tree) match { + case Some(namedTree) => namedTree :: sourceImports(tree, namedTree.source) + case None => Nil } - sourceTreeOfClass(sym.rootTreeContaining(id)) } } } diff --git a/language-server/src/dotty/tools/languageserver/DottyLanguageServer.scala b/language-server/src/dotty/tools/languageserver/DottyLanguageServer.scala index 0712e0f3029c..6fe1b228cb6a 100644 --- a/language-server/src/dotty/tools/languageserver/DottyLanguageServer.scala +++ b/language-server/src/dotty/tools/languageserver/DottyLanguageServer.scala @@ -301,7 +301,7 @@ class DottyLanguageServer extends LanguageServer val pos = sourcePosition(driver, uri, params.getPosition) val path = Interactive.pathTo(driver.openedTrees(uri), pos) - val definitions = Interactive.findDefinitions(path, driver).toList + val definitions = Interactive.findDefinitions(path, pos, driver).toList definitions.flatMap(d => location(d.namePos, positionMapperFor(d.source))).asJava } @@ -311,34 +311,38 @@ class DottyLanguageServer extends LanguageServer val includes = { val includeDeclaration = params.getContext.isIncludeDeclaration - Include.references | Include.overriding | (if (includeDeclaration) Include.definitions else 0) + Include.references | Include.overriding | Include.imports | + (if (includeDeclaration) Include.definitions else Include.empty) } + val uriTrees = driver.openedTrees(uri) val pos = sourcePosition(driver, uri, params.getPosition) - val (definitions, originalSymbol, originalSymbolName) = { + val (definitions, originalSymbols) = { implicit val ctx: Context = driver.currentCtx val path = Interactive.pathTo(driver.openedTrees(uri), pos) - val originalSymbol = Interactive.enclosingSourceSymbol(path) - val originalSymbolName = originalSymbol.name.sourceModuleName.toString - val definitions = Interactive.findDefinitions(path, driver) + val definitions = Interactive.findDefinitions(path, pos, driver) + val originalSymbols = Interactive.enclosingSourceSymbols(path, pos) - (definitions, originalSymbol, originalSymbolName) + (definitions, originalSymbols) } val references = { // Collect the information necessary to look into each project separately: representation of // `originalSymbol` in this project, the context and correct Driver. - val perProjectInfo = inProjectsSeeing(driver, definitions, originalSymbol) - - perProjectInfo.flatMap { (remoteDriver, ctx, definition) => - val trees = remoteDriver.sourceTreesContaining(originalSymbolName)(ctx) - val matches = Interactive.findTreesMatching(trees, includes, definition)(ctx) - matches.map(tree => location(tree.namePos(ctx), positionMapperFor(tree.source))) + val perProjectInfo = inProjectsSeeing(driver, definitions, originalSymbols) + + perProjectInfo.flatMap { (remoteDriver, ctx, definitions) => + definitions.flatMap { definition => + val name = definition.name(ctx).sourceModuleName.toString + val trees = remoteDriver.sourceTreesContaining(name)(ctx) + val matches = Interactive.findTreesMatching(trees, includes, definition)(ctx) + matches.map(tree => location(tree.namePos(ctx), positionMapperFor(tree.source))) + } } }.toList - references.flatten.asJava + references.flatten.distinct.asJava } override def rename(params: RenameParams) = computeAsync { cancelToken => @@ -346,25 +350,60 @@ class DottyLanguageServer extends LanguageServer val driver = driverFor(uri) implicit val ctx = driver.currentCtx + val uriTrees = driver.openedTrees(uri) val pos = sourcePosition(driver, uri, params.getPosition) - val sym = Interactive.enclosingSourceSymbol(driver.openedTrees(uri), pos) + val path = Interactive.pathTo(uriTrees, pos) + val syms = Interactive.enclosingSourceSymbols(path, pos) + val newName = params.getNewName + + def findRenamedReferences(trees: List[SourceTree], syms: List[Symbol], withName: Name): List[SourceTree] = { + val includes = Include.all + syms.flatMap { sym => + Interactive.findTreesMatching(trees, Include.all, sym, t => Interactive.sameName(t.name, withName)) + } + } - if (sym == NoSymbol) new WorkspaceEdit() - else { - val trees = driver.allTreesContaining(sym.name.sourceModuleName.toString) - val newName = params.getNewName - val includes = - Include.references | Include.definitions | Include.linkedClass | Include.overriding - val refs = Interactive.findTreesMatching(trees, includes, sym) + val refs = + path match { + // Selected a renaming in an import node + case Thicket(_ :: (rename: Ident) :: Nil) :: (_: Import) :: rest if rename.pos.contains(pos.pos) => + findRenamedReferences(uriTrees, syms, rename.name) + + // Selected a reference that has been renamed + case (nameTree: NameTree) :: rest if Interactive.isRenamed(nameTree) => + findRenamedReferences(uriTrees, syms, nameTree.name) + + case _ => + val (include, allSymbols) = + if (syms.exists(_.allOverriddenSymbols.nonEmpty)) { + showMessageRequest(MessageType.Info, + RENAME_OVERRIDDEN_QUESTION, + List( + RENAME_OVERRIDDEN -> (() => (Include.all, syms.flatMap(s => s :: s.allOverriddenSymbols.toList))), + RENAME_NO_OVERRIDDEN -> (() => (Include.all.except(Include.overridden), syms))) + ).get.getOrElse((Include.empty, List.empty[Symbol])) + } else { + (Include.all, syms) + } + + val names = allSymbols.map(_.name.sourceModuleName).toSet + val trees = names.flatMap(name => driver.allTreesContaining(name.toString)).toList + allSymbols.flatMap { sym => + Interactive.findTreesMatching(trees, + include, + sym, + t => names.exists(Interactive.sameName(t.name, _))) + } + } - val changes = refs.groupBy(ref => toUriOption(ref.source)) + val changes = + refs.groupBy(ref => toUriOption(ref.source)) .flatMap((uriOpt, ref) => uriOpt.map(uri => (uri.toString, ref))) .mapValues(refs => refs.flatMap(ref => - range(ref.namePos, positionMapperFor(ref.source)).map(nameRange => new TextEdit(nameRange, newName))).asJava) + range(ref.namePos, positionMapperFor(ref.source)).map(nameRange => new TextEdit(nameRange, newName))).distinct.asJava) - new WorkspaceEdit(changes.asJava) - } + new WorkspaceEdit(changes.asJava) } override def documentHighlight(params: TextDocumentPositionParams) = computeAsync { cancelToken => @@ -374,16 +413,17 @@ class DottyLanguageServer extends LanguageServer val pos = sourcePosition(driver, uri, params.getPosition) val uriTrees = driver.openedTrees(uri) - val sym = Interactive.enclosingSourceSymbol(uriTrees, pos) + val path = Interactive.pathTo(uriTrees, pos) + val syms = Interactive.enclosingSourceSymbols(path, pos) + val includes = Include.all.except(Include.linkedClass) - if (sym == NoSymbol) Nil.asJava - else { - val refs = Interactive.namedTrees(uriTrees, Include.references | Include.overriding, sym) + syms.flatMap { sym => + val refs = Interactive.findTreesMatching(uriTrees, includes, sym) (for { - ref <- refs if !ref.tree.symbol.isPrimaryConstructor + ref <- refs nameRange <- range(ref.namePos, positionMapperFor(ref.source)) - } yield new DocumentHighlight(nameRange, DocumentHighlightKind.Read)).asJava - } + } yield new DocumentHighlight(nameRange, DocumentHighlightKind.Read)) + }.distinct.asJava } override def hover(params: TextDocumentPositionParams) = computeAsync { cancelToken => @@ -393,15 +433,20 @@ class DottyLanguageServer extends LanguageServer val pos = sourcePosition(driver, uri, params.getPosition) val trees = driver.openedTrees(uri) + val path = Interactive.pathTo(trees, pos) val tp = Interactive.enclosingType(trees, pos) val tpw = tp.widenTermRefExpr if (tp.isError || tpw == NoType) null // null here indicates that no response should be sent else { - val symbol = Interactive.enclosingSourceSymbol(trees, pos) - val docComment = ParsedComment.docOf(symbol) - val content = hoverContent(Some(tpw.show), docComment) - new Hover(content, null) + Interactive.enclosingSourceSymbols(path, pos) match { + case Nil => + null + case symbols => + val docComments = symbols.flatMap(ParsedComment.docOf) + val content = hoverContent(Some(tpw.show), docComments) + new Hover(content, null) + } } } @@ -412,7 +457,7 @@ class DottyLanguageServer extends LanguageServer val uriTrees = driver.openedTrees(uri) - val defs = Interactive.namedTrees(uriTrees, includeReferences = false, _ => true) + val defs = Interactive.namedTrees(uriTrees, Include.empty, _ => true) (for { d <- defs if !isWorksheetWrapper(d) info <- symbolInfo(d.tree.symbol, d.namePos, positionMapperFor(d.source)) @@ -437,21 +482,24 @@ class DottyLanguageServer extends LanguageServer val pos = sourcePosition(driver, uri, params.getPosition) - val (definitions, originalSymbol) = { + val (definitions, originalSymbols) = { implicit val ctx: Context = driver.currentCtx val path = Interactive.pathTo(driver.openedTrees(uri), pos) - val originalSymbol = Interactive.enclosingSourceSymbol(path) - val definitions = Interactive.findDefinitions(path, driver) - (definitions, originalSymbol) + val originalSymbols = Interactive.enclosingSourceSymbols(path, pos) + val definitions = Interactive.findDefinitions(path, pos, driver) + (definitions, originalSymbols) } val implementations = { - val perProjectInfo = inProjectsSeeing(driver, definitions, originalSymbol) + val perProjectInfo = inProjectsSeeing(driver, definitions, originalSymbols) - perProjectInfo.flatMap { (remoteDriver, ctx, definition) => + perProjectInfo.flatMap { (remoteDriver, ctx, definitions) => val trees = remoteDriver.sourceTrees(ctx) - val predicate = Interactive.implementationFilter(definition)(ctx) - val matches = Interactive.namedTrees(trees, includeReferences = false, predicate)(ctx) + val predicate: NameTree => Boolean = { + val predicates = definitions.map(Interactive.implementationFilter(_)(ctx)) + tree => predicates.exists(_(tree)) + } + val matches = Interactive.namedTrees(trees, Include.empty, predicate)(ctx) matches.map(tree => location(tree.namePos(ctx), positionMapperFor(tree.source))) } }.toList @@ -508,32 +556,60 @@ class DottyLanguageServer extends LanguageServer } /** - * Finds projects that can see any of `definitions`, translate `symbol` in their universe. + * Finds projects that can see any of `definitions`, translate `symbols` in their universe. * * @param baseDriver The driver responsible for the trees in `definitions` and `symbol`. * @param definitions The definitions to consider when looking for projects. - * @param symbol A symbol to translate in the universes of the remote projects. + * @param symbol Symbols to translate in the universes of the remote projects. * @return A list consisting of the remote drivers, their context, and the translation of `symbol` * into their universe. */ private def inProjectsSeeing(baseDriver: InteractiveDriver, definitions: List[SourceTree], - symbol: Symbol): List[(InteractiveDriver, Context, Symbol)] = { + symbols: List[Symbol]): List[(InteractiveDriver, Context, List[Symbol])] = { val projects = projectsSeeing(definitions)(baseDriver.currentCtx) projects.toList.map { config => val remoteDriver = drivers(config) val ctx = remoteDriver.currentCtx - val definition = Interactive.localize(symbol, baseDriver, remoteDriver) - (remoteDriver, ctx, definition) + val definitions = symbols.map(Interactive.localize(_, baseDriver, remoteDriver)) + (remoteDriver, ctx, definitions) } } + /** + * Send a `window/showMessageRequest` to the client, asking to choose between `choices`, and + * perform the associated operation. + * + * @param tpe The type of the request + * @param message The message accompanying the request + * @param choices The choices and their associated operation + * @return A future that will complete with the result of executing the action corresponding to + * the user's response. + */ + private def showMessageRequest[T](tpe: MessageType, + message: String, + choices: List[(String, () => T)]): CompletableFuture[Option[T]] = { + val options = choices.map((title, _) => new MessageActionItem(title)) + val request = new ShowMessageRequestParams(options.asJava) + request.setMessage(message) + request.setType(tpe) + + client.showMessageRequest(request).thenApply { (answer: MessageActionItem) => + choices.find(_._1 == answer.getTitle).map { + case (_, action) => action() + } + } + } } object DottyLanguageServer { /** Configuration file normally generated by sbt-dotty */ final val IDE_CONFIG_FILE = ".dotty-ide.json" + final val RENAME_OVERRIDDEN_QUESTION = "Do you want to rename the base member, or only this member?" + final val RENAME_OVERRIDDEN= "Rename the base member" + final val RENAME_NO_OVERRIDDEN = "Rename only this member" + /** Convert an lsp4j.Position to a SourcePosition */ def sourcePosition(driver: InteractiveDriver, uri: URI, pos: lsp4j.Position): SourcePosition = { val actualPosition = @@ -734,7 +810,7 @@ object DottyLanguageServer { } private def hoverContent(typeInfo: Option[String], - comment: Option[ParsedComment] + comments: List[ParsedComment] )(implicit ctx: Context): lsp4j.MarkupContent = { val buf = new StringBuilder typeInfo.foreach { info => @@ -743,8 +819,7 @@ object DottyLanguageServer { |``` |""".stripMargin) } - - comment.foreach { comment => + comments.foreach { comment => buf.append(comment.renderAsMarkdown) } diff --git a/language-server/test/dotty/tools/languageserver/DefinitionTest.scala b/language-server/test/dotty/tools/languageserver/DefinitionTest.scala index b215a7a277c3..fa71b14384f7 100644 --- a/language-server/test/dotty/tools/languageserver/DefinitionTest.scala +++ b/language-server/test/dotty/tools/languageserver/DefinitionTest.scala @@ -249,4 +249,77 @@ class DefinitionTest { .definition(m7 to m8, List(m3 to m4)) } + @Test def goToDefinitionImport: Unit = { + withSources( + code"""package a + class ${m1}Foo${m2}""", + code"""package b + import a.${m3}Foo${m4} + class C extends ${m5}Foo${m6}""" + ).definition(m1 to m2, List(m1 to m2)) + .definition(m3 to m4, List(m1 to m2)) + .definition(m5 to m6, List(m1 to m2)) + } + + @Test def goToDefinitionRenamedImport: Unit = { + withSources( + code"""package a + class ${m1}Foo${m2}""", + code"""package b + import a.{${m3}Foo${m4} => ${m5}Bar${m6}} + class C extends ${m7}Bar${m8}""" + ).definition(m1 to m2, List(m1 to m2)) + .definition(m3 to m4, List(m1 to m2)) + .definition(m5 to m6, List(m1 to m2)) + .definition(m7 to m8, List(m1 to m2)) + } + + @Test def goToDefinitionImportAlternatives: Unit = { + withSources( + code"""package a + class ${m1}Foo${m2} + object ${m3}Foo${m4}""", + code"""package b + import a.${m5}Foo${m6} + class C extends ${m7}Foo${m8} { + val bar = ${m9}Foo${m10} + }""" + ).definition(m1 to m2, List(m1 to m2)) + .definition(m3 to m4, List(m3 to m4)) + .definition(m5 to m6, List(m1 to m2, m3 to m4)) + .definition(m7 to m8, List(m1 to m2)) + .definition(m9 to m10, List(m3 to m4)) + } + + @Test def goToDefinitionImportAlternativesWithRename: Unit = { + withSources( + code"""package a + class ${m1}Foo${m2} + object ${m3}Foo${m4}""", + code"""package b + import a.{${m5}Foo${m6} => ${m7}Bar${m8}} + class C extends ${m9}Bar${m10} { + val buzz = ${m11}Bar${m12} + }""" + ).definition(m1 to m2, List(m1 to m2)) + .definition(m3 to m4, List(m3 to m4)) + .definition(m5 to m6, List(m1 to m2, m3 to m4)) + .definition(m7 to m8, List(m1 to m2, m3 to m4)) + .definition(m9 to m10, List(m1 to m2)) + .definition(m11 to m12, List(m3 to m4)) + } + + @Test def multipleImportsPerLineWithRename: Unit = { + withSources( + code"""object A { class ${m1}B${m2}; class ${m3}C${m4} }""", + code"""import A.{${m5}B${m6} => ${m7}B2${m8}, ${m9}C${m10} => ${m11}C2${m12}} + class E""" + ).definition(m1 to m2, List(m1 to m2)) + .definition(m3 to m4, List(m3 to m4)) + .definition(m5 to m6, List(m1 to m2)) + .definition(m7 to m8, List(m1 to m2)) + .definition(m9 to m10, List(m3 to m4)) + .definition(m11 to m12, List(m3 to m4)) + } + } diff --git a/language-server/test/dotty/tools/languageserver/HighlightTest.scala b/language-server/test/dotty/tools/languageserver/HighlightTest.scala index 3bf0957be1b8..8c8b0a445eda 100644 --- a/language-server/test/dotty/tools/languageserver/HighlightTest.scala +++ b/language-server/test/dotty/tools/languageserver/HighlightTest.scala @@ -25,4 +25,105 @@ class HighlightTest { .highlight(m3 to m4, (m1 to m2, DocumentHighlightKind.Read), (m3 to m4, DocumentHighlightKind.Read)) } + @Test def importHighlight0: Unit = { + code"""object ${m1}Foo${m2} { def ${m5}bar${m6}: Int = 0 } + trait Bar { import ${m3}Foo${m4}._; def buzz = ${m7}bar${m8} } + trait Baz { def ${m9}bar${m10}: Int = 1 }""".withSource + + .highlight(m1 to m2, (m1 to m2, DocumentHighlightKind.Read), (m3 to m4, DocumentHighlightKind.Read)) + .highlight(m3 to m4, (m1 to m2, DocumentHighlightKind.Read), (m3 to m4, DocumentHighlightKind.Read)) + .highlight(m5 to m6, (m5 to m6, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m7 to m8, (m5 to m6, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m9 to m10, (m9 to m10, DocumentHighlightKind.Read)) + } + + @Test def importHighlight1: Unit = { + code"""import ${m1}Foo${m2}._ + object ${m3}Foo${m4} { def ${m5}bar${m6}: Int = 0 } + trait Bar { def buzz = ${m7}bar${m8} }""".withSource + + .highlight(m1 to m2, (m1 to m2, DocumentHighlightKind.Read), (m3 to m4, DocumentHighlightKind.Read)) + .highlight(m3 to m4, (m1 to m2, DocumentHighlightKind.Read), (m3 to m4, DocumentHighlightKind.Read)) + .highlight(m5 to m6, (m5 to m6, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m7 to m8, (m5 to m6, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + } + + @Test def importHighlight2: Unit = { + code"""object ${m1}Foo${m2} { object ${m3}Bar${m4} { object ${m5}Baz${m6} } } + trait Buzz { import ${m7}Foo${m8}.${m9}Bar${m10}.${m11}Baz${m12} }""".withSource + + .highlight(m1 to m2, (m1 to m2, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m3 to m4, (m3 to m4, DocumentHighlightKind.Read), (m9 to m10, DocumentHighlightKind.Read)) + .highlight(m5 to m6, (m5 to m6, DocumentHighlightKind.Read), (m11 to m12, DocumentHighlightKind.Read)) + .highlight(m7 to m8, (m1 to m2, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m9 to m10, (m3 to m4, DocumentHighlightKind.Read), (m9 to m10, DocumentHighlightKind.Read)) + .highlight(m11 to m12, (m5 to m6, DocumentHighlightKind.Read), (m11 to m12, DocumentHighlightKind.Read)) + } + + @Test def importHighlight3: Unit = { + code"""import ${m1}Foo${m2}.${m3}Bar${m4} + object ${m5}Foo${m6} { object ${m7}Bar${m8} }""".withSource + + .highlight(m1 to m2, (m1 to m2, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read)) + .highlight(m3 to m4, (m3 to m4, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m5 to m6, (m1 to m2, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read)) + .highlight(m7 to m8, (m3 to m4, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + } + + @Test def importHighlightClassAndCompanion: Unit = { + code"""object Foo { object ${m1}Bar${m2}; class ${m3}Bar${m4} } + trait Buzz { import Foo.${m5}Bar${m6} }""".withSource + .highlight(m1 to m2, (m1 to m2, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read)) + .highlight(m3 to m4, (m3 to m4, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read)) + .highlight(m5 to m6, (m3 to m4, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read), (m1 to m2, DocumentHighlightKind.Read)) + } + + @Test def importHighlightWithRename: Unit = { + code"""object ${m1}Foo${m2} { object ${m3}Bar${m4} { object ${m5}Baz${m6} } } + trait Buzz { import ${m7}Foo${m8}.${m9}Bar${m10}.{${m11}Baz${m12} => ${m13}Quux${m14}}""".withSource + + .highlight(m1 to m2, (m1 to m2, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m3 to m4, (m3 to m4, DocumentHighlightKind.Read), (m9 to m10, DocumentHighlightKind.Read)) + .highlight(m5 to m6, (m5 to m6, DocumentHighlightKind.Read), (m11 to m12, DocumentHighlightKind.Read), (m13 to m14, DocumentHighlightKind.Read)) + .highlight(m7 to m8, (m1 to m2, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m9 to m10, (m3 to m4, DocumentHighlightKind.Read), (m9 to m10, DocumentHighlightKind.Read)) + .highlight(m11 to m12, (m5 to m6, DocumentHighlightKind.Read), (m11 to m12, DocumentHighlightKind.Read), (m13 to m14, DocumentHighlightKind.Read)) + .highlight(m13 to m14, (m5 to m6, DocumentHighlightKind.Read), (m11 to m12, DocumentHighlightKind.Read), (m13 to m14, DocumentHighlightKind.Read)) + } + + @Test def importHighlightClassAndCompanionWithRename: Unit = { + code"""object ${m1}Foo${m2} { object ${m3}Bar${m4}; class ${m5}Bar${m6} } + trait Buzz { import ${m7}Foo${m8}.{${m9}Bar${m10} => ${m11}Baz${m12}} }""".withSource + + .highlight(m1 to m2, (m1 to m2, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m3 to m4, (m3 to m4, DocumentHighlightKind.Read), (m9 to m10, DocumentHighlightKind.Read), (m11 to m12, DocumentHighlightKind.Read)) + .highlight(m5 to m6, (m5 to m6, DocumentHighlightKind.Read), (m9 to m10, DocumentHighlightKind.Read), (m11 to m12, DocumentHighlightKind.Read)) + .highlight(m7 to m8, (m1 to m2, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m9 to m10, (m3 to m4, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read), (m9 to m10, DocumentHighlightKind.Read), (m11 to m12, DocumentHighlightKind.Read)) + .highlight(m11 to m12, (m3 to m4, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read), (m9 to m10, DocumentHighlightKind.Read), (m11 to m12, DocumentHighlightKind.Read)) + } + + @Test def importHighlightMembers: Unit = { + code"""object Foo { def ${m1}bar${m2} = 2; type ${m3}bar${m4} = fizz; class fizz } + trait Quux { import Foo.{${m5}bar${m6} => ${m7}buzz${m8}} }""".withSource + + .highlight(m1 to m2, (m1 to m2, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m3 to m4, (m3 to m4, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m5 to m6, (m1 to m2, DocumentHighlightKind.Read), (m3 to m4, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m7 to m8, (m1 to m2, DocumentHighlightKind.Read), (m3 to m4, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + } + + @Test def multipleImportsPerLineWithRename: Unit = { + withSources( + code"""object A { class ${m1}B${m2}; class ${m3}C${m4} } + import A.{${m5}B${m6} => ${m7}B2${m8}, ${m9}C${m10} => ${m11}C2${m12}} + class E""" + ).highlight(m1 to m2, (m1 to m2, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m3 to m4, (m3 to m4, DocumentHighlightKind.Read), (m9 to m10, DocumentHighlightKind.Read), (m11 to m12, DocumentHighlightKind.Read)) + .highlight(m5 to m6, (m1 to m2, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m7 to m8, (m1 to m2, DocumentHighlightKind.Read), (m5 to m6, DocumentHighlightKind.Read), (m7 to m8, DocumentHighlightKind.Read)) + .highlight(m9 to m10, (m3 to m4, DocumentHighlightKind.Read), (m9 to m10, DocumentHighlightKind.Read), (m11 to m12, DocumentHighlightKind.Read)) + .highlight(m11 to m12, (m3 to m4, DocumentHighlightKind.Read), (m9 to m10, DocumentHighlightKind.Read), (m11 to m12, DocumentHighlightKind.Read)) + } + } diff --git a/language-server/test/dotty/tools/languageserver/ReferencesTest.scala b/language-server/test/dotty/tools/languageserver/ReferencesTest.scala index 9b2425582157..bfe811b2751e 100644 --- a/language-server/test/dotty/tools/languageserver/ReferencesTest.scala +++ b/language-server/test/dotty/tools/languageserver/ReferencesTest.scala @@ -217,4 +217,133 @@ class ReferencesTest { .references(m1 to m2, List(m3 to m4), withDecl = false) } + @Test def importReference1: Unit = { + code"""import ${m1}Foo${m2}._ + object ${m3}Foo${m4} { def ${m5}bar${m6}: Int = 0 } + trait Bar { def buzz = ${m7}bar${m8} }""".withSource + + .references(m1 to m2, List(m1 to m2, m3 to m4), withDecl = true) + .references(m1 to m2, List(m1 to m2), withDecl = false) + .references(m3 to m4, List(m1 to m2, m3 to m4), withDecl = true) + .references(m3 to m4, List(m1 to m2), withDecl = false) + .references(m5 to m6, List(m5 to m6, m7 to m8), withDecl = true) + .references(m5 to m6, List(m7 to m8), withDecl = false) + .references(m7 to m8, List(m5 to m6, m7 to m8), withDecl = true) + .references(m7 to m8, List(m7 to m8), withDecl = false) + } + + @Test def importReference2: Unit = { + code"""object ${m1}Foo${m2} { object ${m3}Bar${m4} { object ${m5}Baz${m6} } } + trait Buzz { import ${m7}Foo${m8}.${m9}Bar${m10}.${m11}Baz${m12} }""".withSource + + .references(m1 to m2, List(m1 to m2, m7 to m8), withDecl = true) + .references(m1 to m2, List(m7 to m8), withDecl = false) + .references(m3 to m4, List(m3 to m4, m9 to m10), withDecl = true) + .references(m3 to m4, List(m9 to m10), withDecl = false) + .references(m5 to m6, List(m5 to m6, m11 to m12), withDecl = true) + .references(m5 to m6, List(m11 to m12), withDecl = false) + .references(m7 to m8, List(m1 to m2, m7 to m8), withDecl = true) + .references(m7 to m8, List(m7 to m8), withDecl = false) + .references(m9 to m10, List(m3 to m4, m9 to m10), withDecl = true) + .references(m9 to m10, List(m9 to m10), withDecl = false) + .references(m11 to m12, List(m5 to m6, m11 to m12), withDecl = true) + .references(m11 to m12, List(m11 to m12), withDecl = false) + } + + @Test def importReference3: Unit = { + code"""import ${m1}Foo${m2}.${m3}Bar${m4} + object ${m5}Foo${m6} { object ${m7}Bar${m8} }""".withSource + + .references(m1 to m2, List(m1 to m2, m5 to m6), withDecl = true) + .references(m1 to m2, List(m1 to m2), withDecl = false) + .references(m3 to m4, List(m3 to m4, m7 to m8), withDecl = true) + .references(m3 to m4, List(m3 to m4), withDecl = false) + .references(m5 to m6, List(m1 to m2, m5 to m6), withDecl = true) + .references(m5 to m6, List(m1 to m2), withDecl = false) + .references(m7 to m8, List(m3 to m4, m7 to m8), withDecl = true) + .references(m7 to m8, List(m3 to m4), withDecl = false) + } + + @Test def importReferenceClassAndCompanion: Unit = { + code"""object Foo { object ${m1}Bar${m2}; class ${m3}Bar${m4} } + trait Buzz { import Foo.${m5}Bar${m6} }""".withSource + .references(m1 to m2, List(m1 to m2, m5 to m6), withDecl = true) + .references(m1 to m2, List(m5 to m6), withDecl = false) + .references(m3 to m4, List(m3 to m4, m5 to m6), withDecl = true) + .references(m3 to m4, List(m5 to m6), withDecl = false) + .references(m5 to m6, List(m1 to m2, m3 to m4, m5 to m6), withDecl = true) + .references(m5 to m6, List(m5 to m6), withDecl = false) + } + + @Test def importReferenceWithRename: Unit = { + code"""object ${m1}Foo${m2} { object ${m3}Bar${m4} { object ${m5}Baz${m6} } } + trait Buzz { import ${m7}Foo${m8}.${m9}Bar${m10}.{${m11}Baz${m12} => ${m13}Quux${m14}}""".withSource + + .references(m1 to m2, List(m1 to m2, m7 to m8), withDecl = true) + .references(m1 to m2, List(m7 to m8), withDecl = false) + .references(m3 to m4, List(m3 to m4, m9 to m10), withDecl = true) + .references(m3 to m4, List(m9 to m10), withDecl = false) + .references(m5 to m6, List(m5 to m6, m11 to m12, m13 to m14), withDecl = true) + .references(m5 to m6, List(m11 to m12, m13 to m14), withDecl = false) + .references(m7 to m8, List(m1 to m2, m7 to m8), withDecl = true) + .references(m7 to m8, List(m7 to m8), withDecl = false) + .references(m9 to m10, List(m3 to m4, m9 to m10), withDecl = true) + .references(m9 to m10, List(m9 to m10), withDecl = false) + .references(m11 to m12, List(m5 to m6, m11 to m12, m13 to m14), withDecl = true) + .references(m11 to m12, List(m11 to m12, m13 to m14), withDecl = false) + .references(m13 to m14, List(m5 to m6, m11 to m12, m13 to m14), withDecl = true) + .references(m13 to m14, List(m11 to m12, m13 to m14), withDecl = false) + } + + @Test def importReferenceClassAndCompanionWithRename: Unit = { + code"""object ${m1}Foo${m2} { object ${m3}Bar${m4}; class ${m5}Bar${m6} } + trait Buzz { import ${m7}Foo${m8}.{${m9}Bar${m10} => ${m11}Baz${m12}} }""".withSource + + .references(m1 to m2, List(m1 to m2, m7 to m8), withDecl = true) + .references(m1 to m2, List(m7 to m8), withDecl = false) + .references(m3 to m4, List(m3 to m4, m9 to m10, m11 to m12), withDecl = true) + .references(m3 to m4, List(m9 to m10, m11 to m12), withDecl = false) + .references(m5 to m6, List(m5 to m6, m9 to m10, m11 to m12), withDecl = true) + .references(m5 to m6, List(m9 to m10, m11 to m12), withDecl = false) + .references(m7 to m8, List(m1 to m2, m7 to m8), withDecl = true) + .references(m7 to m8, List(m7 to m8), withDecl = false) + .references(m9 to m10, List(m3 to m4, m5 to m6, m9 to m10, m11 to m12), withDecl = true) + .references(m9 to m10, List(m9 to m10, m11 to m12), withDecl = false) + .references(m11 to m12, List(m3 to m4, m5 to m6, m9 to m10, m11 to m12), withDecl = true) + .references(m11 to m12, List(m9 to m10, m11 to m12), withDecl = false) + } + + @Test def importReferenceMembers: Unit = { + code"""object Foo { def ${m1}bar${m2} = 2; type ${m3}bar${m4} = fizz; class fizz } + trait Quux { import Foo.{${m5}bar${m6} => ${m7}buzz${m8}} }""".withSource + + .references(m1 to m2, List(m1 to m2, m5 to m6, m7 to m8), withDecl = true) + .references(m1 to m2, List(m5 to m6, m7 to m8), withDecl = false) + .references(m3 to m4, List(m3 to m4, m5 to m6, m7 to m8), withDecl = true) + .references(m3 to m4, List(m5 to m6, m7 to m8), withDecl = false) + .references(m5 to m6, List(m1 to m2, m3 to m4, m5 to m6, m7 to m8), withDecl = true) + .references(m5 to m6, List(m5 to m6, m7 to m8), withDecl = false) + .references(m7 to m8, List(m1 to m2, m3 to m4, m5 to m6, m7 to m8), withDecl = true) + .references(m7 to m8, List(m5 to m6, m7 to m8), withDecl = false) + } + + @Test def multipleImportsPerLineWithRename: Unit = { + withSources( + code"""object A { class ${m1}B${m2}; class ${m3}C${m4} }""", + code"""import A.{${m5}B${m6} => ${m7}B2${m8}, ${m9}C${m10} => ${m11}C2${m12}} + class E""" + ).references(m1 to m2, List(m1 to m2, m5 to m6, m7 to m8), withDecl = true) + .references(m1 to m2, List(m5 to m6, m7 to m8), withDecl = false) + .references(m3 to m4, List(m3 to m4, m9 to m10, m11 to m12), withDecl = true) + .references(m3 to m4, List(m9 to m10, m11 to m12), withDecl = false) + .references(m5 to m6, List(m1 to m2, m5 to m6, m7 to m8), withDecl = true) + .references(m5 to m6, List(m5 to m6, m7 to m8), withDecl = false) + .references(m7 to m8, List(m1 to m2, m5 to m6, m7 to m8), withDecl = true) + .references(m7 to m8, List(m5 to m6, m7 to m8), withDecl = false) + .references(m9 to m10, List(m3 to m4, m9 to m10, m11 to m12), withDecl = true) + .references(m9 to m10, List(m9 to m10, m11 to m12), withDecl = false) + .references(m11 to m12, List(m3 to m4, m9 to m10, m11 to m12), withDecl = true) + .references(m11 to m12, List(m9 to m10, m11 to m12), withDecl = false) + } + } diff --git a/language-server/test/dotty/tools/languageserver/RenameTest.scala b/language-server/test/dotty/tools/languageserver/RenameTest.scala index 199778b7352a..be1dc06fec2c 100644 --- a/language-server/test/dotty/tools/languageserver/RenameTest.scala +++ b/language-server/test/dotty/tools/languageserver/RenameTest.scala @@ -3,6 +3,7 @@ package dotty.tools.languageserver import org.junit.Test import dotty.tools.languageserver.util.Code._ +import dotty.tools.languageserver.util.CodeRange import dotty.tools.languageserver.util.embedded.CodeMarker class RenameTest { @@ -70,6 +71,171 @@ class RenameTest { testRenameFrom(m1) testRenameFrom(m2) + testRenameFrom(m3) + testRenameFrom(m4) } + @Test def renameImport: Unit = { + def testRenameFrom(m: CodeMarker) = + withSources( + code"""object A { class ${m1}C${m2} }""", + code"""import A.${m3}C${m4} + object B""" + ).rename(m, "NewName", Set(m1 to m2, m3 to m4)) + + testRenameFrom(m1) + testRenameFrom(m2) + testRenameFrom(m3) + testRenameFrom(m4) + } + + @Test def renameRenamedImport: Unit = { + def sources = + withSources( + code"""object A { class ${m1}C${m2} }""", + code"""import A.{${m3}C${m4} => ${m5}D${m6}} + object B { new ${m7}D${m8} }""" + ) + def testRename(m: CodeMarker, expectations: Set[CodeRange]) = + sources.rename(m, "NewName", expectations) + + testRename(m1, Set(m1 to m2, m3 to m4)) + testRename(m2, Set(m1 to m2, m3 to m4)) + testRename(m3, Set(m1 to m2, m3 to m4)) + testRename(m4, Set(m1 to m2, m3 to m4)) + testRename(m5, Set(m5 to m6, m7 to m8)) + testRename(m6, Set(m5 to m6, m7 to m8)) + testRename(m7, Set(m5 to m6, m7 to m8)) + testRename(m8, Set(m5 to m6, m7 to m8)) + } + + @Test def renameRenamingImport: Unit = { + def sources = + withSources( + code"""object A { class ${m1}C${m2}; object ${m3}C${m4} }""", + code"""object O1 { + import A.{${m5}C${m6} => ${m7}Renamed${m8}} + class C2 extends ${m9}Renamed${m10} { val x = ${m11}Renamed${m12} } + } + object O2 { + import A.{${m13}C${m14} => ${m15}Renamed${m16}} + class C3 extends ${m17}Renamed${m18} { val x = ${m19}Renamed${m20} } + }""" + ) + def testRename(m: CodeMarker, expectations: Set[CodeRange]) = + sources.rename(m, "NewName", expectations) + + testRename(m1, Set(m1 to m2, m3 to m4, m5 to m6, m13 to m14)) + testRename(m2, Set(m1 to m2, m3 to m4, m5 to m6, m13 to m14)) + testRename(m3, Set(m1 to m2, m3 to m4, m5 to m6, m13 to m14)) + testRename(m4, Set(m1 to m2, m3 to m4, m5 to m6, m13 to m14)) + testRename(m5, Set(m1 to m2, m3 to m4, m5 to m6, m13 to m14)) + testRename(m6, Set(m1 to m2, m3 to m4, m5 to m6, m13 to m14)) + + testRename(m7, Set(m7 to m8, m9 to m10, m11 to m12, m15 to m16, m17 to m18, m19 to m20)) + testRename(m8, Set(m7 to m8, m9 to m10, m11 to m12, m15 to m16, m17 to m18, m19 to m20)) + testRename(m9, Set(m7 to m8, m9 to m10, m11 to m12, m15 to m16, m17 to m18, m19 to m20)) + testRename(m10, Set(m7 to m8, m9 to m10, m11 to m12, m15 to m16, m17 to m18, m19 to m20)) + testRename(m11, Set(m7 to m8, m9 to m10, m11 to m12, m15 to m16, m17 to m18, m19 to m20)) + testRename(m12, Set(m7 to m8, m9 to m10, m11 to m12, m15 to m16, m17 to m18, m19 to m20)) + + testRename(m13, Set(m1 to m2, m3 to m4, m5 to m6, m13 to m14)) + testRename(m14, Set(m1 to m2, m3 to m4, m5 to m6, m13 to m14)) + + testRename(m15, Set(m7 to m8, m9 to m10, m11 to m12, m15 to m16, m17 to m18, m19 to m20)) + testRename(m16, Set(m7 to m8, m9 to m10, m11 to m12, m15 to m16, m17 to m18, m19 to m20)) + testRename(m17, Set(m7 to m8, m9 to m10, m11 to m12, m15 to m16, m17 to m18, m19 to m20)) + testRename(m18, Set(m7 to m8, m9 to m10, m11 to m12, m15 to m16, m17 to m18, m19 to m20)) + testRename(m19, Set(m7 to m8, m9 to m10, m11 to m12, m15 to m16, m17 to m18, m19 to m20)) + testRename(m20, Set(m7 to m8, m9 to m10, m11 to m12, m15 to m16, m17 to m18, m19 to m20)) + + } + + @Test def renameRenamingImportNested: Unit = { + def sources = + withSources( + code"""object A { class C }""", + code"""import A.{C => ${m1}Renamed${m2}} + object O { + import A.{C => ${m3}Renamed${m4}} + class C2 extends ${m5}Renamed${m6} { self: ${m15}Renamed${m16} => + import A.{C => ${m7}Renamed${m8}} + } + 123 match { + case x if new ${m9}Renamed${m10} == null => ??? + case foo if { + import A.{C => ${m11}Renamed${m12}} + new ${m13}Renamed${m14} != null + } => ??? + } + new A.C + }""" + ) + def testRename(m: CodeMarker, expectations: Set[CodeRange]) = + sources.rename(m, "NewName", expectations) + + testRename(m1, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m2, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m3, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m4, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m5, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m6, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m7, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m8, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m9, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m10, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m11, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m12, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m13, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m14, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m15, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + testRename(m16, Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8, m9 to m10, m11 to m12, m13 to m14, m15 to m16)) + } + + @Test def renameImportWithRenaming: Unit = { + def testRename(m: CodeMarker) = + withSources( + code"""object A { class ${m1}C${m2} }""", + code"""import A.${m3}C${m4} + object O { + class B extends ${m5}C${m6} { + import A.{${m7}C${m8} => Renamed} + def foo = new Renamed + } + }""" + ).rename(m, "NewName", Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8)) + + testRename(m1) + testRename(m2) + testRename(m3) + testRename(m4) + testRename(m5) + testRename(m6) + testRename(m7) + testRename(m8) + } + + @Test def renameOverridden: Unit = { + def testRename(m: CodeMarker, expectations: Set[CodeRange], withOverridden: Option[Boolean]) = + withSources( + code"""class A { def ${m1}foo${m2}: Int = 0 } + class B extends A { override def ${m3}foo${m4}: Int = 1 } + class C extends A { override def ${m5}foo${m6}: Int = 2 }""" + ).rename(m, "NewName", expectations, withOverridden) + + testRename(m1, Set(m1 to m2, m3 to m4, m5 to m6), withOverridden = None) + testRename(m2, Set(m1 to m2, m3 to m4, m5 to m6), withOverridden = None) + testRename(m3, Set(m1 to m2, m3 to m4, m5 to m6), withOverridden = Some(true)) + testRename(m4, Set(m1 to m2, m3 to m4, m5 to m6), withOverridden = Some(true)) + testRename(m5, Set(m1 to m2, m3 to m4, m5 to m6), withOverridden = Some(true)) + testRename(m6, Set(m1 to m2, m3 to m4, m5 to m6), withOverridden = Some(true)) + testRename(m3, Set(m3 to m4), withOverridden = Some(false)) + testRename(m4, Set(m3 to m4), withOverridden = Some(false)) + testRename(m5, Set(m5 to m6), withOverridden = Some(false)) + testRename(m6, Set(m5 to m6), withOverridden = Some(false)) + + } + + + } diff --git a/language-server/test/dotty/tools/languageserver/util/Code.scala b/language-server/test/dotty/tools/languageserver/util/Code.scala index 81edf2c322d9..c5a742d21fa5 100644 --- a/language-server/test/dotty/tools/languageserver/util/Code.scala +++ b/language-server/test/dotty/tools/languageserver/util/Code.scala @@ -30,6 +30,8 @@ object Code { val m16 = new CodeMarker("m16") val m17 = new CodeMarker("m17") val m18 = new CodeMarker("m18") + val m19 = new CodeMarker("m19") + val m20 = new CodeMarker("m20") implicit class CodeHelper(val sc: StringContext) extends AnyVal { diff --git a/language-server/test/dotty/tools/languageserver/util/CodeTester.scala b/language-server/test/dotty/tools/languageserver/util/CodeTester.scala index 8614e13bf721..c808ff6b8d05 100644 --- a/language-server/test/dotty/tools/languageserver/util/CodeTester.scala +++ b/language-server/test/dotty/tools/languageserver/util/CodeTester.scala @@ -92,14 +92,20 @@ class CodeTester(projects: List[Project]) { * Performs a workspace-wide renaming of the symbol under `marker`, verifies that the positions to * update match `expected`. * - * @param marker The position from which to ask for renaming. - * @param newName The new name to give to the symbol. - * @param expected The expected positions to change. + * @param marker The position from which to ask for renaming. + * @param newName The new name to give to the symbol. + * @param expected The expected positions to change. + * @param withOverridden If `None`, do not expect the server to ask whether to include overridden + * symbol. Otherwise, wait for this question from the server and include + * overridden symbols if this is true. * * @see dotty.tools.languageserver.util.actions.CodeRename */ - def rename(marker: CodeMarker, newName: String, expected: Set[CodeRange]): this.type = - doAction(new CodeRename(marker, newName, expected)) // TODO apply changes to the sources and positions + def rename(marker: CodeMarker, + newName: String, + expected: Set[CodeRange], + withOverridden: Option[Boolean] = None): this.type = + doAction(new CodeRename(marker, newName, expected, withOverridden)) // TODO apply changes to the sources and positions /** * Queries for all the symbols referenced in the source file in `marker`, verifies that they match diff --git a/language-server/test/dotty/tools/languageserver/util/actions/Action.scala b/language-server/test/dotty/tools/languageserver/util/actions/Action.scala index f4654c763ca7..746d9c4c01d4 100644 --- a/language-server/test/dotty/tools/languageserver/util/actions/Action.scala +++ b/language-server/test/dotty/tools/languageserver/util/actions/Action.scala @@ -25,4 +25,12 @@ trait Action { /** The client that executes this action. */ def client: Exec[TestClient] = implicitly[TestClient] + /** An ordering for `Location` that compares string representations. */ + implicit def locationOrdering: Ordering[org.eclipse.lsp4j.Location] = + Ordering.by(_.toString) + + /** An ordering for `Range` that compares string representations. */ + implicit def rangeOrdering: Ordering[org.eclipse.lsp4j.Range] = + Ordering.by(_.toString) + } diff --git a/language-server/test/dotty/tools/languageserver/util/actions/CodeDefinition.scala b/language-server/test/dotty/tools/languageserver/util/actions/CodeDefinition.scala index 0e1e41774ac2..0ece0d9a5447 100644 --- a/language-server/test/dotty/tools/languageserver/util/actions/CodeDefinition.scala +++ b/language-server/test/dotty/tools/languageserver/util/actions/CodeDefinition.scala @@ -17,8 +17,8 @@ import org.junit.Assert.assertEquals class CodeDefinition(override val range: CodeRange, expected: Seq[CodeRange]) extends ActionOnRange { override def onMarker(marker: CodeMarker): Exec[Unit] = { - val results = server.definition(marker.toTextDocumentPositionParams).get().asScala.toSeq - val expectedLocations = expected.map(_.toLocation) + val results = server.definition(marker.toTextDocumentPositionParams).get().asScala.toSeq.sorted + val expectedLocations = expected.map(_.toLocation).sorted assertEquals(expectedLocations, results) } diff --git a/language-server/test/dotty/tools/languageserver/util/actions/CodeDocumentHighlight.scala b/language-server/test/dotty/tools/languageserver/util/actions/CodeDocumentHighlight.scala index eefbac6dd40f..58cb288716cc 100644 --- a/language-server/test/dotty/tools/languageserver/util/actions/CodeDocumentHighlight.scala +++ b/language-server/test/dotty/tools/languageserver/util/actions/CodeDocumentHighlight.scala @@ -20,9 +20,9 @@ class CodeDocumentHighlight(override val range: CodeRange, expected: Seq[(CodeRange, DocumentHighlightKind)]) extends ActionOnRange { override def onMarker(marker: CodeMarker): Exec[Unit] = { - val expectedPairs = expected.map { case (codeRange, kind) => (codeRange.toRange, kind) } + val expectedPairs = expected.map { case (codeRange, kind) => (codeRange.toRange, kind) }.sorted val results = server.documentHighlight(marker.toTextDocumentPositionParams).get() - val resultPairs = results.asScala.map { result => (result.getRange, result.getKind) } + val resultPairs = results.asScala.map { result => (result.getRange, result.getKind) }.sorted assertEquals(expectedPairs, resultPairs) } diff --git a/language-server/test/dotty/tools/languageserver/util/actions/CodeRename.scala b/language-server/test/dotty/tools/languageserver/util/actions/CodeRename.scala index 5e69d00b4bbb..331e8529e2fd 100644 --- a/language-server/test/dotty/tools/languageserver/util/actions/CodeRename.scala +++ b/language-server/test/dotty/tools/languageserver/util/actions/CodeRename.scala @@ -2,8 +2,13 @@ package dotty.tools.languageserver.util.actions import dotty.tools.languageserver.util.embedded.CodeMarker import dotty.tools.languageserver.util.{CodeRange, PositionContext} +import dotty.tools.languageserver.DottyLanguageServer.{RENAME_OVERRIDDEN, RENAME_NO_OVERRIDDEN} -import org.junit.Assert.{assertEquals, assertNull} +import org.junit.Assert.{assertEquals, assertNull, fail} + +import org.eclipse.lsp4j.{MessageActionItem, ShowMessageRequestParams} + +import java.util.concurrent.CompletableFuture import scala.collection.JavaConverters._ @@ -17,10 +22,31 @@ import scala.collection.JavaConverters._ */ class CodeRename(override val marker: CodeMarker, newName: String, - expected: Set[CodeRange]) extends ActionOnMarker { + expected: Set[CodeRange], + withOverridden: Option[Boolean]) extends ActionOnMarker { + + private final val TIMEOUT_MS = 10000 override def execute(): Exec[Unit] = { - val results = server.rename(marker.toRenameParams(newName)).get() + val query = server.rename(marker.toRenameParams(newName)) + + withOverridden.foreach { includeOverridden => + var question: (ShowMessageRequestParams, CompletableFuture[MessageActionItem]) = null + val startTime = System.currentTimeMillis() + do { + Thread.sleep(50) + question = client.requests.get.headOption.orNull + } while (question == null && System.currentTimeMillis() - startTime < TIMEOUT_MS) + + if (question == null) fail("The server didn't ask about overridden symbols.") + + val answerStr = if (includeOverridden) RENAME_OVERRIDDEN else RENAME_NO_OVERRIDDEN + val action = question._1.getActions.asScala.find(_.getTitle == answerStr).get + question._2.complete(action) + } + + val results = query.get() + val changes = results.getChanges.asScala.mapValues(_.asScala.toSet.map(ch => (ch.getNewText, ch.getRange))) val expectedChanges = expected.groupBy(_.file.uri).mapValues(_.map(range => (newName, range.toRange))) diff --git a/language-server/test/dotty/tools/languageserver/util/server/TestClient.scala b/language-server/test/dotty/tools/languageserver/util/server/TestClient.scala index 5d2effffa26e..bba69c5e2b9a 100644 --- a/language-server/test/dotty/tools/languageserver/util/server/TestClient.scala +++ b/language-server/test/dotty/tools/languageserver/util/server/TestClient.scala @@ -23,6 +23,7 @@ class TestClient extends WorksheetClient { val diagnostics = new Log[PublishDiagnosticsParams] val telemetry = new Log[Any] val worksheetOutput = new Log[WorksheetRunOutput] + val requests = new Log[(ShowMessageRequestParams, CompletableFuture[MessageActionItem])] override def logMessage(message: MessageParams) = { log += message @@ -37,8 +38,9 @@ class TestClient extends WorksheetClient { } override def showMessageRequest(requestParams: ShowMessageRequestParams) = { - log += requestParams - new CompletableFuture[MessageActionItem] + val reply = new CompletableFuture[MessageActionItem] + requests += ((requestParams, reply)) + reply } override def publishDiagnostics(diagnosticsParams: PublishDiagnosticsParams) = {