Skip to content

Commit 377ce48

Browse files
authored
Merge pull request #10473 from prolativ/extension-method-code-completion
Fix #10264: Add code completion for extension methods
2 parents c372fa1 + b9da708 commit 377ce48

File tree

5 files changed

+228
-7
lines changed

5 files changed

+228
-7
lines changed

compiler/src/dotty/tools/dotc/interactive/Completion.scala

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ import dotty.tools.dotc.core.Flags._
1313
import dotty.tools.dotc.core.Names.{Name, TermName}
1414
import dotty.tools.dotc.core.NameKinds.SimpleNameKind
1515
import dotty.tools.dotc.core.NameOps._
16-
import dotty.tools.dotc.core.Symbols.{NoSymbol, Symbol, defn}
16+
import dotty.tools.dotc.core.Symbols.{NoSymbol, Symbol, defn, newSymbol}
1717
import dotty.tools.dotc.core.Scopes
1818
import dotty.tools.dotc.core.StdNames.{nme, tpnme}
19+
import dotty.tools.dotc.core.TypeComparer
1920
import dotty.tools.dotc.core.TypeError
20-
import dotty.tools.dotc.core.Types.{NameFilter, NamedType, NoType, Type}
21+
import dotty.tools.dotc.core.Types.{ExprType, MethodType, NameFilter, NamedType, NoType, PolyType, Type}
2122
import dotty.tools.dotc.printing.Texts._
2223
import dotty.tools.dotc.util.{NameTransformer, NoSourcePosition, SourcePosition}
2324

@@ -118,7 +119,7 @@ object Completion {
118119

119120
if (buffer.mode != Mode.None)
120121
path match {
121-
case Select(qual, _) :: _ => buffer.addMemberCompletions(qual)
122+
case Select(qual, _) :: _ => buffer.addSelectionCompletions(path, qual)
122123
case Import(expr, _) :: _ => buffer.addMemberCompletions(expr) // TODO: distinguish given from plain imports
123124
case (_: untpd.ImportSelector) :: Import(expr, _) :: _ => buffer.addMemberCompletions(expr)
124125
case _ => buffer.addScopeCompletions
@@ -214,6 +215,66 @@ object Completion {
214215
.foreach(addAccessibleMembers)
215216
}
216217

218+
def addExtensionCompletions(path: List[Tree], qual: Tree)(using Context): Unit =
219+
def applyExtensionReceiver(methodSymbol: Symbol, methodName: TermName): Symbol = {
220+
val newMethodType = methodSymbol.info match {
221+
case mt: MethodType =>
222+
mt.resultType match {
223+
case resType: MethodType => resType
224+
case resType => ExprType(resType)
225+
}
226+
case pt: PolyType =>
227+
PolyType(pt.paramNames)(_ => pt.paramInfos, _ => pt.resultType.resultType)
228+
}
229+
230+
newSymbol(owner = qual.symbol, methodName, methodSymbol.flags, newMethodType)
231+
}
232+
233+
val matchingNamePrefix = completionPrefix(path, pos)
234+
235+
def extractDefinedExtensionMethods(types: Seq[Type]) =
236+
types
237+
.flatMap(_.membersBasedOnFlags(required = ExtensionMethod, excluded = EmptyFlags))
238+
.collect{ denot =>
239+
denot.name.toTermName match {
240+
case name if name.startsWith(matchingNamePrefix) => (denot.symbol, name)
241+
}
242+
}
243+
244+
// There are four possible ways for an extension method to be applicable:
245+
246+
// 1. The extension method is visible under a simple name, by being defined or inherited or imported in a scope enclosing the reference.
247+
val extMethodsInScope =
248+
val buf = completionBuffer(path, pos)
249+
buf.addScopeCompletions
250+
buf.completions.mappings.toList.flatMap {
251+
case (termName, symbols) => symbols.map(s => (s, termName))
252+
}
253+
254+
// 2. The extension method is a member of some given instance that is visible at the point of the reference.
255+
val givensInScope = ctx.implicits.eligible(defn.AnyType).map(_.implicitRef.underlyingRef)
256+
val extMethodsFromGivensInScope = extractDefinedExtensionMethods(givensInScope)
257+
258+
// 3. The reference is of the form r.m and the extension method is defined in the implicit scope of the type of r.
259+
val implicitScopeCompanions = ctx.run.implicitScope(qual.tpe).companionRefs.showAsList
260+
val extMethodsFromImplicitScope = extractDefinedExtensionMethods(implicitScopeCompanions)
261+
262+
// 4. The reference is of the form r.m and the extension method is defined in some given instance in the implicit scope of the type of r.
263+
val givensInImplicitScope = implicitScopeCompanions.flatMap(_.membersBasedOnFlags(required = Given, excluded = EmptyFlags)).map(_.symbol.info)
264+
val extMethodsFromGivensInImplicitScope = extractDefinedExtensionMethods(givensInImplicitScope)
265+
266+
val availableExtMethods = extMethodsFromGivensInImplicitScope ++ extMethodsFromImplicitScope ++ extMethodsFromGivensInScope ++ extMethodsInScope
267+
val extMethodsWithAppliedReceiver = availableExtMethods.collect {
268+
case (symbol, termName) if ctx.typer.isApplicableExtensionMethod(symbol.termRef, qual.tpe) =>
269+
applyExtensionReceiver(symbol, termName)
270+
}
271+
272+
for (symbol <- extMethodsWithAppliedReceiver) do add(symbol, symbol.name)
273+
274+
def addSelectionCompletions(path: List[Tree], qual: Tree)(using Context): Unit =
275+
addExtensionCompletions(path, qual)
276+
addMemberCompletions(qual)
277+
217278
/**
218279
* If `sym` exists, no symbol with the same name is already included, and it satisfies the
219280
* inclusion filter, then add it to the completions.

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2176,4 +2176,9 @@ trait Applications extends Compatibility {
21762176
report.error(em"not an extension method: $methodRef", receiver.srcPos)
21772177
app
21782178
}
2179+
2180+
def isApplicableExtensionMethod(ref: TermRef, receiver: Type)(using Context) =
2181+
ref.symbol.is(ExtensionMethod)
2182+
&& !receiver.isBottomType
2183+
&& isApplicableMethodRef(ref, receiver :: Nil, WildcardType)
21792184
}

compiler/src/dotty/tools/dotc/typer/ImportSuggestions.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,7 @@ trait ImportSuggestions:
222222
site.member(name)
223223
.alternatives
224224
.map(mbr => TermRef(site, mbr.symbol))
225-
.filter(ref =>
226-
ref.symbol.is(ExtensionMethod)
227-
&& isApplicableMethodRef(ref, argType :: Nil, WildcardType))
225+
.filter(ref => ctx.typer.isApplicableExtensionMethod(ref, argType))
228226
.headOption
229227

230228
try

compiler/test/dotty/tools/repl/TabcompleteTests.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class TabcompleteTests extends ReplTest {
116116
val comp = tabComplete("(null: AnyRef).")
117117
assertEquals(
118118
List("!=", "##", "->", "==", "asInstanceOf", "clone", "ensuring", "eq", "equals", "finalize", "formatted",
119-
"getClass", "hashCode", "isInstanceOf", "ne", "notify", "notifyAll", "synchronized", "toString", "wait", ""),
119+
"getClass", "hashCode", "isInstanceOf", "ne", "nn", "notify", "notifyAll", "synchronized", "toString", "wait", ""),
120120
comp.distinct.sorted)
121121
}
122122

language-server/test/dotty/tools/languageserver/CompletionTest.scala

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,4 +288,161 @@ class CompletionTest {
288288
|import Foo.b$m1""".withSource
289289
.completion(m1, Set(("bar", Field, "type and lazy value bar")))
290290
}
291+
292+
@Test def completeExtensionMethodWithoutParameter: Unit = {
293+
code"""object Foo
294+
|extension (foo: Foo.type) def xxxx = 1
295+
|object Main { Foo.xx${m1} }""".withSource
296+
.completion(m1, Set(("xxxx", Method, "=> Int")))
297+
}
298+
299+
@Test def completeExtensionMethodWithParameter: Unit = {
300+
code"""object Foo
301+
|extension (foo: Foo.type) def xxxx(i: Int) = i
302+
|object Main { Foo.xx${m1} }""".withSource
303+
.completion(m1, Set(("xxxx", Method, "(i: Int): Int")))
304+
}
305+
306+
@Test def completeExtensionMethodWithTypeParameter: Unit = {
307+
code"""object Foo
308+
|extension [A](foo: Foo.type) def xxxx: Int = 1
309+
|object Main { Foo.xx${m1} }""".withSource
310+
.completion(m1, Set(("xxxx", Method, "[A] => Int")))
311+
}
312+
313+
@Test def completeExtensionMethodWithParameterAndTypeParameter: Unit = {
314+
code"""object Foo
315+
|extension [A](foo: Foo.type) def xxxx(a: A) = a
316+
|object Main { Foo.xx${m1} }""".withSource
317+
.completion(m1, Set(("xxxx", Method, "[A](a: A): A")))
318+
}
319+
320+
@Test def completeExtensionMethodFromExtenionWithAUsingSection: Unit = {
321+
code"""object Foo
322+
|trait Bar
323+
|trait Baz
324+
|given Bar = new Bar {}
325+
|given Baz = new Baz {}
326+
|extension (foo: Foo.type)(using Bar, Baz) def xxxx = 1
327+
|object Main { Foo.xx${m1} }""".withSource
328+
.completion(m1, Set(("xxxx", Method, "(using x$1: Bar, x$2: Baz): Int")))
329+
}
330+
331+
@Test def completeExtensionMethodFromExtenionWithMultipleUsingSections: Unit = {
332+
code"""object Foo
333+
|trait Bar
334+
|trait Baz
335+
|given Bar = new Bar {}
336+
|given Baz = new Baz {}
337+
|extension (foo: Foo.type)(using Bar)(using Baz) def xxxx = 1
338+
|object Main { Foo.xx${m1} }""".withSource
339+
.completion(m1, Set(("xxxx", Method, "(using x$1: Bar)(using x$2: Baz): Int")))
340+
}
341+
342+
@Test def completeInheritedExtensionMethod: Unit = {
343+
code"""object Foo
344+
|trait FooOps {
345+
| extension (foo: Foo.type) def xxxx = 1
346+
|}
347+
|object Main extends FooOps { Foo.xx${m1} }""".withSource
348+
.completion(m1, Set(("xxxx", Method, "=> Int")))
349+
}
350+
351+
@Test def completeRenamedExtensionMethod: Unit = {
352+
code"""object Foo
353+
|object FooOps {
354+
| extension (foo: Foo.type) def xxxx = 1
355+
|}
356+
|import FooOps.{xxxx => yyyy}
357+
|object Main { Foo.yy${m1} }""".withSource
358+
.completion(m1, Set(("yyyy", Method, "=> Int")))
359+
}
360+
361+
@Test def completeExtensionMethodFromGivenInstanceDefinedInScope: Unit = {
362+
code"""object Foo
363+
|trait FooOps
364+
|given FooOps {
365+
| extension (foo: Foo.type) def xxxx = 1
366+
|}
367+
|object Main { Foo.xx${m1} }""".withSource
368+
.completion(m1, Set(("xxxx", Method, "=> Int")))
369+
}
370+
371+
@Test def completeExtensionMethodFromImportedGivenInstance: Unit = {
372+
code"""object Foo
373+
|trait FooOps
374+
|object Bar {
375+
| given FooOps {
376+
| extension (foo: Foo.type) def xxxx = 1
377+
| }
378+
|}
379+
|import Bar.given
380+
|object Main { Foo.xx${m1} }""".withSource
381+
.completion(m1, Set(("xxxx", Method, "=> Int")))
382+
}
383+
384+
@Test def completeExtensionMethodFromImplicitScope: Unit = {
385+
code"""case class Foo(i: Int)
386+
|object Foo {
387+
| extension (foo: Foo) def xxxx = foo.i
388+
|}
389+
|object Main { Foo(123).xx${m1} }""".withSource
390+
.completion(m1, Set(("xxxx", Method, "=> Int")))
391+
}
392+
393+
@Test def completeExtensionMethodFromGivenInImplicitScope: Unit = {
394+
code"""trait Bar
395+
|case class Foo(i: Int)
396+
|object Foo {
397+
| given Bar {
398+
| extension (foo: Foo) def xxxx = foo.i
399+
| }
400+
|}
401+
|object Main { Foo(123).xx${m1} }""".withSource
402+
.completion(m1, Set(("xxxx", Method, "=> Int")))
403+
}
404+
405+
@Test def completeExtensionMethodOnResultOfImplicitConversion: Unit = {
406+
code"""import scala.language.implicitConversions
407+
|case class Foo(i: Int)
408+
|extension (foo: Foo) def xxxx = foo.i
409+
|given Conversion[Int, Foo] = Foo(_)
410+
|object Main { 123.xx${m1} }""".withSource
411+
.completion(m1, Set(("xxxx", Method, "=> Int")))
412+
}
413+
414+
@Test def dontCompleteExtensionMethodWithMismatchedName: Unit = {
415+
code"""object Foo
416+
|extension (foo: Foo.type) def xxxx = 1
417+
|object Main { Foo.yy${m1} }""".withSource
418+
.completion(m1, Set())
419+
}
420+
421+
@Test def preferNormalMethodToExtensionMethod: Unit = {
422+
code"""object Foo {
423+
| def xxxx = "abcd"
424+
|}
425+
|object FooOps {
426+
| extension (foo: Foo.type) def xxxx = 1
427+
|}
428+
|object Main { Foo.xx${m1} }""".withSource
429+
.completion(m1, Set(("xxxx", Method, "=> String")))
430+
}
431+
432+
@Test def preferExtensionMethodFromExplicitScope: Unit = {
433+
code"""object Foo
434+
|extension (foo: Foo.type) def xxxx = 1
435+
|object FooOps {
436+
| extension (foo: Foo.type) def xxxx = "abcd"
437+
|}
438+
|object Main { Foo.xx${m1} }""".withSource
439+
.completion(m1, Set(("xxxx", Method, "=> Int")))
440+
}
441+
442+
@Test def dontCompleteInapplicableExtensionMethod: Unit = {
443+
code"""case class Foo[A](a: A)
444+
|extension (foo: Foo[Int]) def xxxx = foo.a
445+
|object Main { Foo("abc").xx${m1} }""".withSource
446+
.completion(m1, Set())
447+
}
291448
}

0 commit comments

Comments
 (0)