Skip to content

Commit c10275d

Browse files
Backport "Support completions for extension definition parameter" to LTS (#20688)
Backports #18331 to the LTS branch. PR submitted by the release tooling. [skip ci]
2 parents e2a9516 + e8a8428 commit c10275d

File tree

9 files changed

+421
-209
lines changed

9 files changed

+421
-209
lines changed

compiler/src/dotty/tools/dotc/interactive/Completion.scala

Lines changed: 83 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
package dotty.tools.dotc.interactive
22

3-
import scala.language.unsafeNulls
4-
53
import dotty.tools.dotc.ast.untpd
4+
import dotty.tools.dotc.ast.NavigateAST
65
import dotty.tools.dotc.config.Printers.interactiv
76
import dotty.tools.dotc.core.Contexts._
87
import dotty.tools.dotc.core.Decorators._
@@ -25,6 +24,10 @@ import dotty.tools.dotc.util.SourcePosition
2524

2625
import scala.collection.mutable
2726
import scala.util.control.NonFatal
27+
import dotty.tools.dotc.core.ContextOps.localContext
28+
import dotty.tools.dotc.core.Names
29+
import dotty.tools.dotc.core.Types
30+
import dotty.tools.dotc.core.Symbols
2831

2932
/**
3033
* One of the results of a completion query.
@@ -37,18 +40,17 @@ import scala.util.control.NonFatal
3740
*/
3841
case class Completion(label: String, description: String, symbols: List[Symbol])
3942

40-
object Completion {
43+
object Completion:
4144

4245
import dotty.tools.dotc.ast.tpd._
4346

4447
/** Get possible completions from tree at `pos`
4548
*
4649
* @return offset and list of symbols for possible completions
4750
*/
48-
def completions(pos: SourcePosition)(using Context): (Int, List[Completion]) = {
49-
val path = Interactive.pathTo(ctx.compilationUnit.tpdTree, pos.span)
51+
def completions(pos: SourcePosition)(using Context): (Int, List[Completion]) =
52+
val path: List[Tree] = Interactive.pathTo(ctx.compilationUnit.tpdTree, pos.span)
5053
computeCompletions(pos, path)(using Interactive.contextOfPath(path).withPhase(Phases.typerPhase))
51-
}
5254

5355
/**
5456
* Inspect `path` to determine what kinds of symbols should be considered.
@@ -60,10 +62,11 @@ object Completion {
6062
*
6163
* Otherwise, provide no completion suggestion.
6264
*/
63-
def completionMode(path: List[Tree], pos: SourcePosition): Mode =
64-
path match {
65-
case Ident(_) :: Import(_, _) :: _ => Mode.ImportOrExport
66-
case (ref: RefTree) :: _ =>
65+
def completionMode(path: List[untpd.Tree], pos: SourcePosition): Mode =
66+
path match
67+
case untpd.Ident(_) :: untpd.Import(_, _) :: _ => Mode.ImportOrExport
68+
case untpd.Ident(_) :: (_: untpd.ImportSelector) :: _ => Mode.ImportOrExport
69+
case (ref: untpd.RefTree) :: _ =>
6770
if (ref.name.isTermName) Mode.Term
6871
else if (ref.name.isTypeName) Mode.Type
6972
else Mode.None
@@ -72,9 +75,8 @@ object Completion {
7275
if sel.imported.span.contains(pos.span) then Mode.ImportOrExport
7376
else Mode.None // Can't help completing the renaming
7477

75-
case (_: ImportOrExport) :: _ => Mode.ImportOrExport
78+
case (_: untpd.ImportOrExport) :: _ => Mode.ImportOrExport
7679
case _ => Mode.None
77-
}
7880

7981
/** When dealing with <errors> in varios palces we check to see if they are
8082
* due to incomplete backticks. If so, we ensure we get the full prefix
@@ -101,10 +103,13 @@ object Completion {
101103
case (sel: untpd.ImportSelector) :: _ =>
102104
completionPrefix(sel.imported :: Nil, pos)
103105

106+
case untpd.Ident(_) :: (sel: untpd.ImportSelector) :: _ if !sel.isGiven =>
107+
completionPrefix(sel.imported :: Nil, pos)
108+
104109
case (tree: untpd.ImportOrExport) :: _ =>
105-
tree.selectors.find(_.span.contains(pos.span)).map { selector =>
110+
tree.selectors.find(_.span.contains(pos.span)).map: selector =>
106111
completionPrefix(selector :: Nil, pos)
107-
}.getOrElse("")
112+
.getOrElse("")
108113

109114
// Foo.`se<TAB> will result in Select(Ident(Foo), <error>)
110115
case (select: untpd.Select) :: _ if select.name == nme.ERROR =>
@@ -118,27 +123,65 @@ object Completion {
118123
if (ref.name == nme.ERROR) ""
119124
else ref.name.toString.take(pos.span.point - ref.span.point)
120125

121-
case _ =>
122-
""
126+
case _ => ""
127+
123128
end completionPrefix
124129

125130
/** Inspect `path` to determine the offset where the completion result should be inserted. */
126-
def completionOffset(path: List[Tree]): Int =
127-
path match {
128-
case (ref: RefTree) :: _ => ref.span.point
131+
def completionOffset(untpdPath: List[untpd.Tree]): Int =
132+
untpdPath match {
133+
case (ref: untpd.RefTree) :: _ => ref.span.point
129134
case _ => 0
130135
}
131136

132-
private def computeCompletions(pos: SourcePosition, path: List[Tree])(using Context): (Int, List[Completion]) = {
133-
val mode = completionMode(path, pos)
134-
val rawPrefix = completionPrefix(path, pos)
137+
/** Some information about the trees is lost after Typer such as Extension method construct
138+
* is expanded into methods. In order to support completions in those cases
139+
* we have to rely on untyped trees and only when types are necessary use typed trees.
140+
*/
141+
def resolveTypedOrUntypedPath(tpdPath: List[Tree], pos: SourcePosition)(using Context): List[untpd.Tree] =
142+
lazy val untpdPath: List[untpd.Tree] = NavigateAST
143+
.pathTo(pos.span, List(ctx.compilationUnit.untpdTree), true).collect:
144+
case untpdTree: untpd.Tree => untpdTree
145+
146+
tpdPath match
147+
case (_: Bind) :: _ => tpdPath
148+
case (_: untpd.TypTree) :: _ => tpdPath
149+
case _ => untpdPath
150+
151+
/** Handle case when cursor position is inside extension method construct.
152+
* The extension method construct is then desugared into methods, and consturct parameters
153+
* are no longer a part of a typed tree, but instead are prepended to method parameters.
154+
*
155+
* @param untpdPath The typed or untyped path to the tree that is being completed
156+
* @param tpdPath The typed path that will be returned if no extension method construct is found
157+
* @param pos The cursor position
158+
*
159+
* @return Typed path to the parameter of the extension construct if found or tpdPath
160+
*/
161+
private def typeCheckExtensionConstructPath(
162+
untpdPath: List[untpd.Tree], tpdPath: List[Tree], pos: SourcePosition
163+
)(using Context): List[Tree] =
164+
untpdPath.collectFirst:
165+
case untpd.ExtMethods(paramss, _) =>
166+
val enclosingParam = paramss.flatten.find(_.span.contains(pos.span))
167+
enclosingParam.map: param =>
168+
ctx.typer.index(paramss.flatten)
169+
val typedEnclosingParam = ctx.typer.typed(param)
170+
Interactive.pathTo(typedEnclosingParam, pos.span)
171+
.flatten.getOrElse(tpdPath)
172+
173+
private def computeCompletions(pos: SourcePosition, tpdPath: List[Tree])(using Context): (Int, List[Completion]) =
174+
val path0 = resolveTypedOrUntypedPath(tpdPath, pos)
175+
val mode = completionMode(path0, pos)
176+
val rawPrefix = completionPrefix(path0, pos)
135177

136178
val hasBackTick = rawPrefix.headOption.contains('`')
137179
val prefix = if hasBackTick then rawPrefix.drop(1) else rawPrefix
138180

139181
val completer = new Completer(mode, prefix, pos)
140182

141-
val completions = path match {
183+
val adjustedPath = typeCheckExtensionConstructPath(path0, tpdPath, pos)
184+
val completions = adjustedPath match
142185
// Ignore synthetic select from `This` because in code it was `Ident`
143186
// See example in dotty.tools.languageserver.CompletionTest.syntheticThis
144187
case Select(qual @ This(_), _) :: _ if qual.span.isSynthetic => completer.scopeCompletions
@@ -147,21 +190,19 @@ object Completion {
147190
case (tree: ImportOrExport) :: _ => completer.directMemberCompletions(tree.expr)
148191
case (_: untpd.ImportSelector) :: Import(expr, _) :: _ => completer.directMemberCompletions(expr)
149192
case _ => completer.scopeCompletions
150-
}
151193

152194
val describedCompletions = describeCompletions(completions)
153195
val backtickedCompletions =
154196
describedCompletions.map(completion => backtickCompletions(completion, hasBackTick))
155197

156-
val offset = completionOffset(path)
198+
val offset = completionOffset(path0)
157199

158200
interactiv.println(i"""completion with pos = $pos,
159201
| prefix = ${completer.prefix},
160202
| term = ${completer.mode.is(Mode.Term)},
161203
| type = ${completer.mode.is(Mode.Type)}
162204
| results = $backtickedCompletions%, %""")
163205
(offset, backtickedCompletions)
164-
}
165206

166207
def backtickCompletions(completion: Completion, hasBackTick: Boolean) =
167208
if hasBackTick || needsBacktick(completion.label) then
@@ -174,17 +215,17 @@ object Completion {
174215
// https://github.com/scalameta/metals/blob/main/mtags/src/main/scala/scala/meta/internal/mtags/KeywordWrapper.scala
175216
// https://github.com/com-lihaoyi/Ammonite/blob/73a874173cd337f953a3edc9fb8cb96556638fdd/amm/util/src/main/scala/ammonite/util/Model.scala
176217
private def needsBacktick(s: String) =
177-
val chunks = s.split("_", -1)
218+
val chunks = s.split("_", -1).nn
178219

179220
val validChunks = chunks.zipWithIndex.forall { case (chunk, index) =>
180-
chunk.forall(Chars.isIdentifierPart) ||
181-
(chunk.forall(Chars.isOperatorPart) &&
221+
chunk.nn.forall(Chars.isIdentifierPart) ||
222+
(chunk.nn.forall(Chars.isOperatorPart) &&
182223
index == chunks.length - 1 &&
183224
!(chunks.lift(index - 1).contains("") && index - 1 == 0))
184225
}
185226

186227
val validStart =
187-
Chars.isIdentifierStart(s(0)) || chunks(0).forall(Chars.isOperatorPart)
228+
Chars.isIdentifierStart(s(0)) || chunks(0).nn.forall(Chars.isOperatorPart)
188229

189230
val valid = validChunks && validStart && !keywords.contains(s)
190231

@@ -216,7 +257,7 @@ object Completion {
216257
* For the results of all `xyzCompletions` methods term names and type names are always treated as different keys in the same map
217258
* and they never conflict with each other.
218259
*/
219-
class Completer(val mode: Mode, val prefix: String, pos: SourcePosition) {
260+
class Completer(val mode: Mode, val prefix: String, pos: SourcePosition):
220261
/** Completions for terms and types that are currently in scope:
221262
* the members of the current class, local definitions and the symbols that have been imported,
222263
* recursively adding completions from outer scopes.
@@ -230,7 +271,7 @@ object Completion {
230271
* (even if the import follows it syntactically)
231272
* - a more deeply nested import shadowing a member or a local definition causes an ambiguity
232273
*/
233-
def scopeCompletions(using context: Context): CompletionMap = {
274+
def scopeCompletions(using context: Context): CompletionMap =
234275
val mappings = collection.mutable.Map.empty[Name, List[ScopedDenotations]].withDefaultValue(List.empty)
235276
def addMapping(name: Name, denots: ScopedDenotations) =
236277
mappings(name) = mappings(name) :+ denots
@@ -302,7 +343,7 @@ object Completion {
302343
}
303344

304345
resultMappings
305-
}
346+
end scopeCompletions
306347

307348
/** Widen only those types which are applied or are exactly nothing
308349
*/
@@ -335,16 +376,16 @@ object Completion {
335376
/** Completions introduced by imports directly in this context.
336377
* Completions from outer contexts are not included.
337378
*/
338-
private def importedCompletions(using Context): CompletionMap = {
379+
private def importedCompletions(using Context): CompletionMap =
339380
val imp = ctx.importInfo
340381

341-
def fromImport(name: Name, nameInScope: Name): Seq[(Name, SingleDenotation)] =
342-
imp.site.member(name).alternatives
343-
.collect { case denot if include(denot, nameInScope) => nameInScope -> denot }
344-
345382
if imp == null then
346383
Map.empty
347384
else
385+
def fromImport(name: Name, nameInScope: Name): Seq[(Name, SingleDenotation)] =
386+
imp.site.member(name).alternatives
387+
.collect { case denot if include(denot, nameInScope) => nameInScope -> denot }
388+
348389
val givenImports = imp.importedImplicits
349390
.map { ref => (ref.implicitName: Name, ref.underlyingRef.denot.asSingleDenotation) }
350391
.filter((name, denot) => include(denot, name))
@@ -370,7 +411,7 @@ object Completion {
370411
}.toSeq.groupByName
371412

372413
givenImports ++ wildcardMembers ++ explicitMembers
373-
}
414+
end importedCompletions
374415

375416
/** Completions from implicit conversions including old style extensions using implicit classes */
376417
private def implicitConversionMemberCompletions(qual: Tree)(using Context): CompletionMap =
@@ -532,7 +573,6 @@ object Completion {
532573
extension [N <: Name](namedDenotations: Seq[(N, SingleDenotation)])
533574
@annotation.targetName("groupByNameTupled")
534575
def groupByName: CompletionMap = namedDenotations.groupMap((name, denot) => name)((name, denot) => denot)
535-
}
536576

537577
private type CompletionMap = Map[Name, Seq[SingleDenotation]]
538578

@@ -545,11 +585,11 @@ object Completion {
545585
* The completion mode: defines what kinds of symbols should be included in the completion
546586
* results.
547587
*/
548-
class Mode(val bits: Int) extends AnyVal {
588+
class Mode(val bits: Int) extends AnyVal:
549589
def is(other: Mode): Boolean = (bits & other.bits) == other.bits
550590
def |(other: Mode): Mode = new Mode(bits | other.bits)
551-
}
552-
object Mode {
591+
592+
object Mode:
553593
/** No symbol should be included */
554594
val None: Mode = new Mode(0)
555595

@@ -561,6 +601,4 @@ object Completion {
561601

562602
/** Both term and type symbols are allowed */
563603
val ImportOrExport: Mode = new Mode(4) | Term | Type
564-
}
565-
}
566604

compiler/src/dotty/tools/repl/ReplCompiler.scala

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ class ReplCompiler extends Compiler:
9393
end compile
9494

9595
final def typeOf(expr: String)(using state: State): Result[String] =
96-
typeCheck(expr).map { tree =>
96+
typeCheck(expr).map { (_, tpdTree) =>
9797
given Context = state.context
98-
tree.rhs match {
98+
tpdTree.rhs match {
9999
case Block(xs, _) => xs.last.tpe.widen.show
100100
case _ =>
101101
"""Couldn't compute the type of your expression, so sorry :(
@@ -129,7 +129,7 @@ class ReplCompiler extends Compiler:
129129
Iterator(sym) ++ sym.allOverriddenSymbols
130130
}
131131

132-
typeCheck(expr).map {
132+
typeCheck(expr).map { (_, tpdTree) => tpdTree match
133133
case ValDef(_, _, Block(stats, _)) if stats.nonEmpty =>
134134
val stat = stats.last.asInstanceOf[tpd.Tree]
135135
if (stat.tpe.isError) stat.tpe.show
@@ -152,7 +152,7 @@ class ReplCompiler extends Compiler:
152152
}
153153
}
154154

155-
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[tpd.ValDef] = {
155+
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[(untpd.ValDef, tpd.ValDef)] = {
156156

157157
def wrapped(expr: String, sourceFile: SourceFile, state: State)(using Context): Result[untpd.PackageDef] = {
158158
def wrap(trees: List[untpd.Tree]): untpd.PackageDef = {
@@ -181,22 +181,32 @@ class ReplCompiler extends Compiler:
181181
}
182182
}
183183

184-
def unwrapped(tree: tpd.Tree, sourceFile: SourceFile)(using Context): Result[tpd.ValDef] = {
185-
def error: Result[tpd.ValDef] =
186-
List(new Diagnostic.Error(s"Invalid scala expression",
187-
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors
184+
def error[Tree <: untpd.Tree](sourceFile: SourceFile): Result[Tree] =
185+
List(new Diagnostic.Error(s"Invalid scala expression",
186+
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors
188187

188+
def unwrappedTypeTree(tree: tpd.Tree, sourceFile0: SourceFile)(using Context): Result[tpd.ValDef] = {
189189
import tpd._
190190
tree match {
191191
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
192192
tmpl.body
193193
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
194-
.getOrElse(error)
194+
.getOrElse(error[tpd.ValDef](sourceFile0))
195195
case _ =>
196-
error
196+
error[tpd.ValDef](sourceFile0)
197197
}
198198
}
199199

200+
def unwrappedUntypedTree(tree: untpd.Tree, sourceFile0: SourceFile)(using Context): Result[untpd.ValDef] =
201+
import untpd._
202+
tree match {
203+
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
204+
tmpl.body
205+
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
206+
.getOrElse(error[untpd.ValDef](sourceFile0))
207+
case _ =>
208+
error[untpd.ValDef](sourceFile0)
209+
}
200210

201211
val src = SourceFile.virtual("<typecheck>", expr)
202212
inContext(state.context.fresh
@@ -209,7 +219,10 @@ class ReplCompiler extends Compiler:
209219
ctx.run.nn.compileUnits(unit :: Nil, ctx)
210220

211221
if (errorsAllowed || !ctx.reporter.hasErrors)
212-
unwrapped(unit.tpdTree, src)
222+
for
223+
tpdTree <- unwrappedTypeTree(unit.tpdTree, src)
224+
untpdTree <- unwrappedUntypedTree(unit.untpdTree, src)
225+
yield untpdTree -> tpdTree
213226
else
214227
ctx.reporter.removeBufferedMessages.errors
215228
}

compiler/src/dotty/tools/repl/ReplDriver.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,11 @@ class ReplDriver(settings: Array[String],
251251
given state: State = newRun(state0)
252252
compiler
253253
.typeCheck(expr, errorsAllowed = true)
254-
.map { tree =>
254+
.map { (untpdTree, tpdTree) =>
255255
val file = SourceFile.virtual("<completions>", expr, maybeIncomplete = true)
256256
val unit = CompilationUnit(file)(using state.context)
257-
unit.tpdTree = tree
257+
unit.untpdTree = untpdTree
258+
unit.tpdTree = tpdTree
258259
given Context = state.context.fresh.setCompilationUnit(unit)
259260
val srcPos = SourcePosition(file, Span(cursor))
260261
val completions = try Completion.completions(srcPos)._2 catch case NonFatal(_) => Nil

0 commit comments

Comments
 (0)