Skip to content

Commit c9cae24

Browse files
authored
Merge pull request #5047 from dotty-staging/fix/4995
Fix #4995: Exclude constructors when finding refs
2 parents 7a8dd59 + 31935f9 commit c9cae24

File tree

4 files changed

+115
-13
lines changed

4 files changed

+115
-13
lines changed

compiler/src/dotty/tools/dotc/interactive/Interactive.scala

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ object Interactive {
2525
type Set = Int
2626
val overridden = 1 // include trees whose symbol is overridden by `sym`
2727
val overriding = 2 // include trees whose symbol overrides `sym` (but for performance only in same source file)
28-
val references = 4 // include references and not just definitions
28+
val references = 4 // include references
29+
val definitions = 8 // include definitions
30+
val linkedClass = 16 // include `symbol.linkedClass`
2931
}
3032

3133
/** Does this tree define a symbol ? */
@@ -304,6 +306,35 @@ object Interactive {
304306
buf.toList
305307
}
306308

309+
/**
310+
* Find trees that match `symbol` in `trees`.
311+
*
312+
* @param trees The trees to inspect.
313+
* @param includes Whether to include references, definitions, etc.
314+
* @param symbol The symbol for which we want to find references.
315+
*/
316+
def findTreesMatching(trees: List[SourceTree],
317+
includes: Include.Set,
318+
symbol: Symbol)(implicit ctx: Context): List[SourceTree] = {
319+
val linkedSym = symbol.linkedClass
320+
val includeReferences = (includes & Include.references) != 0
321+
val includeDeclaration = (includes & Include.definitions) != 0
322+
val includeLinkedClass = (includes & Include.linkedClass) != 0
323+
val predicate: NameTree => Boolean = tree =>
324+
( tree.pos.isSourceDerived
325+
&& !tree.symbol.isConstructor
326+
&& (includeDeclaration || !Interactive.isDefinition(tree))
327+
&& ( Interactive.matchSymbol(tree, symbol, includes)
328+
|| ( includeDeclaration
329+
&& includeLinkedClass
330+
&& linkedSym.exists
331+
&& Interactive.matchSymbol(tree, linkedSym, includes)
332+
)
333+
)
334+
)
335+
namedTrees(trees, includeReferences, predicate)
336+
}
337+
307338
/** The reverse path to the node that closest encloses position `pos`,
308339
* or `Nil` if no such path exists. If a non-empty path is returned it starts with
309340
* the tree closest enclosing `pos` and ends with an element of `trees`.

language-server/src/dotty/tools/languageserver/DottyLanguageServer.scala

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,6 @@ class DottyLanguageServer extends LanguageServer
296296
val driver = driverFor(uri)
297297
implicit val ctx = driver.currentCtx
298298

299-
val includeDeclaration = params.getContext.isIncludeDeclaration
300299
val pos = sourcePosition(driver, uri, params.getPosition)
301300
val sym = Interactive.enclosingSourceSymbol(driver.openedTrees(uri), pos)
302301

@@ -306,9 +305,10 @@ class DottyLanguageServer extends LanguageServer
306305
// only need to look for trees in the target directory if the symbol is defined in the
307306
// current project
308307
val trees = driver.allTreesContaining(sym.name.sourceModuleName.toString)
309-
val refs = Interactive.namedTrees(trees, includeReferences = true, (tree: tpd.NameTree) =>
310-
(includeDeclaration || !Interactive.isDefinition(tree))
311-
&& Interactive.matchSymbol(tree, sym, Include.overriding))
308+
val includeDeclaration = params.getContext.isIncludeDeclaration
309+
val includes =
310+
Include.references | Include.overriding | (if (includeDeclaration) Include.definitions else 0)
311+
val refs = Interactive.findTreesMatching(trees, includes, sym)
312312

313313
refs.map(ref => location(ref.namePos)).asJava
314314
}
@@ -325,13 +325,10 @@ class DottyLanguageServer extends LanguageServer
325325
if (sym == NoSymbol) new WorkspaceEdit()
326326
else {
327327
val trees = driver.allTreesContaining(sym.name.sourceModuleName.toString)
328-
val linkedSym = sym.linkedClass
329328
val newName = params.getNewName
330-
331-
val refs = Interactive.namedTrees(trees, includeReferences = true, tree =>
332-
tree.pos.isSourceDerived
333-
&& (Interactive.matchSymbol(tree, sym, Include.overriding)
334-
|| (linkedSym != NoSymbol && Interactive.matchSymbol(tree, linkedSym, Include.overriding))))
329+
val includes =
330+
Include.references | Include.definitions | Include.linkedClass | Include.overriding
331+
val refs = Interactive.findTreesMatching(trees, includes, sym)
335332

336333
val changes = refs.groupBy(ref => toUri(ref.source).toString).mapValues(_.map(ref => new TextEdit(range(ref.namePos), newName)).asJava)
337334

language-server/test/dotty/tools/languageserver/ReferencesTest.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,33 @@ class ReferencesTest {
1919
.references(m1 to m2, List(m1 to m2, m3 to m4, m5 to m6), withDecl = true)
2020
}
2121

22+
@Test def classReference0: Unit = {
23+
code"class ${m1}Foo${m2} { val a = new ${m3}Foo${m4} }".withSource
24+
.references(m1 to m2, List(m1 to m2, m3 to m4), withDecl = true)
25+
.references(m1 to m2, List(m3 to m4), withDecl = false)
26+
.references(m3 to m4, List(m1 to m2, m3 to m4), withDecl = true)
27+
.references(m3 to m4, List(m3 to m4), withDecl = false)
28+
}
29+
30+
@Test def classReference1: Unit = {
31+
code"class ${m1}Foo${m2}(x: Int) { val a = new ${m3}Foo${m4}(1) }".withSource
32+
.references(m1 to m2, List(m1 to m2, m3 to m4), withDecl = true)
33+
.references(m1 to m2, List(m3 to m4), withDecl = false)
34+
.references(m3 to m4, List(m1 to m2, m3 to m4), withDecl = true)
35+
.references(m3 to m4, List(m3 to m4), withDecl = false)
36+
}
37+
38+
@Test def classReferenceCompanion: Unit = {
39+
code"""class ${m1}Foo${m2}(x: Any)
40+
object ${m3}Foo${m4} { val bar = new ${m5}Foo${m6}(${m7}Foo${m8}) }""".withSource
41+
.references(m1 to m2, List(m1 to m2, m5 to m6), withDecl = true)
42+
.references(m1 to m2, List(m5 to m6), withDecl = false)
43+
.references(m3 to m4, List(m3 to m4, m7 to m8), withDecl = true)
44+
.references(m3 to m4, List(m7 to m8), withDecl = false)
45+
.references(m5 to m6, List(m1 to m2, m5 to m6), withDecl = true)
46+
.references(m5 to m6, List(m5 to m6), withDecl = false)
47+
.references(m7 to m8, List(m3 to m4, m7 to m8), withDecl = true)
48+
.references(m7 to m8, List(m7 to m8), withDecl = false)
49+
}
50+
2251
}

language-server/test/dotty/tools/languageserver/RenameTest.scala

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,57 @@ class RenameTest {
1919
def testRenameFrom(m: CodeMarker) =
2020
withSources(
2121
code"class ${m1}Foo$m2 { new ${m3}Foo$m4 }",
22-
code"class Bar { new ${m5}Foo$m6 }"
23-
).rename(m, "Bar", Set(m1 to m2, m3 to m4, m5 to m6))
22+
code"class Bar { new ${m5}Foo$m6 }",
23+
code"class Baz extends ${m7}Foo${m8}"
24+
).rename(m, "Bar", Set(m1 to m2, m3 to m4, m5 to m6, m7 to m8))
2425

2526
testRenameFrom(m1)
2627
testRenameFrom(m3)
2728
testRenameFrom(m5)
2829
}
2930

31+
@Test def renameObject: Unit = {
32+
def testRenameFrom(m: CodeMarker) =
33+
withSources(
34+
code"object ${m1}Foo${m2}",
35+
code"class Bar { val x = ${m3}Foo${m4} }"
36+
).rename(m, "NewName", Set(m1 to m2, m3 to m4))
37+
38+
testRenameFrom(m1)
39+
testRenameFrom(m2)
40+
}
41+
42+
@Test def renameDef: Unit = {
43+
def testRenameFrom(m: CodeMarker) =
44+
withSources(
45+
code"object Foo { def ${m1}bar${m2} = 0 }",
46+
code"object Buzz { Foo.${m3}bar${m4} }"
47+
).rename(m, "newName", Set(m1 to m2, m3 to m4))
48+
49+
testRenameFrom(m1)
50+
testRenameFrom(m3)
51+
}
52+
53+
@Test def renameClass: Unit = {
54+
def testRenameFrom(m: CodeMarker) =
55+
withSources(
56+
code"class ${m1}Foo${m2}(x: Int)",
57+
code"class Bar extends ${m3}Foo${m4}(1)"
58+
).rename(m, "NewName", Set(m1 to m2, m3 to m4))
59+
60+
testRenameFrom(m1)
61+
testRenameFrom(m3)
62+
}
63+
64+
@Test def renameCaseClass: Unit = {
65+
def testRenameFrom(m: CodeMarker) =
66+
withSources(
67+
code"case class ${m1}Foo${m2}(x: Int)",
68+
code"class Bar extends ${m3}Foo${m4}(1)"
69+
).rename(m, "NewName", Set(m1 to m2, m3 to m4))
70+
71+
testRenameFrom(m1)
72+
testRenameFrom(m2)
73+
}
74+
3075
}

0 commit comments

Comments
 (0)