Skip to content

Commit aa04b2f

Browse files
committed
Set untpdTree in repl compilation unit for completions
1 parent 97f5f1c commit aa04b2f

File tree

5 files changed

+138
-20
lines changed

5 files changed

+138
-20
lines changed

compiler/src/dotty/tools/dotc/interactive/Completion.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,9 @@ object Completion {
136136
*/
137137
def pathBeforeDesugaring(path: List[Tree], pos: SourcePosition)(using Context): List[Tree] =
138138
val hasUntypedTree = path.headOption.forall(NavigateAST.untypedPath(_, exactMatch = true).nonEmpty)
139-
if hasUntypedTree then
140-
path
141-
else
142-
NavigateAST.untypedPath(pos.span).collect:
143-
case tree: untpd.Tree => tree
139+
if hasUntypedTree then path
140+
else NavigateAST.untypedPath(pos.span).collect:
141+
case tree: untpd.Tree => tree
144142

145143
private def computeCompletions(pos: SourcePosition, path: List[Tree])(using Context): (Int, List[Completion]) = {
146144
val path0 = pathBeforeDesugaring(path, pos)

compiler/src/dotty/tools/repl/ReplCompiler.scala

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ class ReplCompiler extends Compiler:
9393
end compile
9494

9595
final def typeOf(expr: String)(using state: State): Result[String] =
96-
typeCheck(expr).map { tree =>
96+
typeCheck(expr).map { (_, tpdTree) =>
9797
given Context = state.context
98-
tree.rhs match {
98+
tpdTree.rhs match {
9999
case Block(xs, _) => xs.last.tpe.widen.show
100100
case _ =>
101101
"""Couldn't compute the type of your expression, so sorry :(
@@ -129,7 +129,7 @@ class ReplCompiler extends Compiler:
129129
Iterator(sym) ++ sym.allOverriddenSymbols
130130
}
131131

132-
typeCheck(expr).map {
132+
typeCheck(expr).map { (_, tpdTree) => tpdTree match
133133
case ValDef(_, _, Block(stats, _)) if stats.nonEmpty =>
134134
val stat = stats.last.asInstanceOf[tpd.Tree]
135135
if (stat.tpe.isError) stat.tpe.show
@@ -152,7 +152,7 @@ class ReplCompiler extends Compiler:
152152
}
153153
}
154154

155-
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[tpd.ValDef] = {
155+
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[(untpd.ValDef, tpd.ValDef)] = {
156156

157157
def wrapped(expr: String, sourceFile: SourceFile, state: State)(using Context): Result[untpd.PackageDef] = {
158158
def wrap(trees: List[untpd.Tree]): untpd.PackageDef = {
@@ -181,22 +181,32 @@ class ReplCompiler extends Compiler:
181181
}
182182
}
183183

184-
def unwrapped(tree: tpd.Tree, sourceFile: SourceFile)(using Context): Result[tpd.ValDef] = {
185-
def error: Result[tpd.ValDef] =
186-
List(new Diagnostic.Error(s"Invalid scala expression",
187-
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors
184+
def error[Tree <: untpd.Tree](sourceFile: SourceFile): Result[Tree] =
185+
List(new Diagnostic.Error(s"Invalid scala expression",
186+
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors
188187

188+
def unwrappedTypeTree(tree: tpd.Tree, sourceFile0: SourceFile)(using Context): Result[tpd.ValDef] = {
189189
import tpd._
190190
tree match {
191191
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
192192
tmpl.body
193193
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
194-
.getOrElse(error)
194+
.getOrElse(error[tpd.ValDef](sourceFile0))
195195
case _ =>
196-
error
196+
error[tpd.ValDef](sourceFile0)
197197
}
198198
}
199199

200+
def unwrappedUntypedTree(tree: untpd.Tree, sourceFile0: SourceFile)(using Context): Result[untpd.ValDef] =
201+
import untpd._
202+
tree match {
203+
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
204+
tmpl.body
205+
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
206+
.getOrElse(error[untpd.ValDef](sourceFile0))
207+
case _ =>
208+
error[untpd.ValDef](sourceFile0)
209+
}
200210

201211
val src = SourceFile.virtual("<typecheck>", expr)
202212
inContext(state.context.fresh
@@ -209,7 +219,10 @@ class ReplCompiler extends Compiler:
209219
ctx.run.nn.compileUnits(unit :: Nil, ctx)
210220

211221
if (errorsAllowed || !ctx.reporter.hasErrors)
212-
unwrapped(unit.tpdTree, src)
222+
for
223+
tpdTree <- unwrappedTypeTree(unit.tpdTree, src)
224+
untpdTree <- unwrappedUntypedTree(unit.untpdTree, src)
225+
yield untpdTree -> tpdTree
213226
else
214227
ctx.reporter.removeBufferedMessages.errors
215228
}

compiler/src/dotty/tools/repl/ReplDriver.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,11 @@ class ReplDriver(settings: Array[String],
251251
given state: State = newRun(state0)
252252
compiler
253253
.typeCheck(expr, errorsAllowed = true)
254-
.map { tree =>
254+
.map { (untpdTree, tpdTree) =>
255255
val file = SourceFile.virtual("<completions>", expr, maybeIncomplete = true)
256256
val unit = CompilationUnit(file)(using state.context)
257-
unit.tpdTree = tree
257+
unit.untpdTree = untpdTree
258+
unit.tpdTree = tpdTree
258259
given Context = state.context.fresh.setCompilationUnit(unit)
259260
val srcPos = SourcePosition(file, Span(cursor))
260261
val completions = try Completion.completions(srcPos)._2 catch case NonFatal(_) => Nil

compiler/test/dotty/tools/repl/TabcompleteTests.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ class TabcompleteTests extends ReplTest {
3232
assertEquals(List("apply"), comp)
3333
}
3434

35+
@Test def tabCompleteInExtensionDefinition = initially {
36+
val comp = tabComplete("extension (x: Lis")
37+
assertEquals(List("List"), comp)
38+
}
39+
3540
@Test def tabCompleteTwiceIn = {
3641
val src1 = "class Foo { def bar(xs: List[Int]) = xs.map"
3742
val src2 = "class Foo { def bar(xs: List[Int]) = xs.mapC"

presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionSuite.scala

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,6 +1330,16 @@ class CompletionSuite extends BaseCompletionSuite:
13301330
)
13311331

13321332
@Test def `extension-definition-scope` =
1333+
check(
1334+
"""|trait Foo
1335+
|object T:
1336+
| extension (x: Fo@@)
1337+
|""".stripMargin,
1338+
"""|Foo test
1339+
|""".stripMargin
1340+
)
1341+
1342+
@Test def `extension-definition-symbol-search` =
13331343
check(
13341344
"""|object T:
13351345
| extension (x: ListBuffe@@)
@@ -1339,11 +1349,102 @@ class CompletionSuite extends BaseCompletionSuite:
13391349
|""".stripMargin,
13401350
)
13411351

1342-
@Test def `extension-definition-symbol-search` =
1352+
@Test def `extension-definition-type-parameter` =
13431353
check(
13441354
"""|trait Foo
13451355
|object T:
1346-
| extension (x: Fo@@)
1356+
| extension [A <: Fo@@]
1357+
|""".stripMargin,
1358+
"""|Foo test
1359+
|""".stripMargin
1360+
)
1361+
1362+
@Test def `extension-definition-type-parameter-symbol-search` =
1363+
check(
1364+
"""|object T:
1365+
| extension [A <: ListBuffe@@]
1366+
|""".stripMargin,
1367+
"""|ListBuffer[T] - scala.collection.mutable
1368+
|ListBuffer - scala.collection.mutable
1369+
|""".stripMargin
1370+
)
1371+
1372+
@Test def `extension-definition-using-param-clause` =
1373+
check(
1374+
"""|trait Foo
1375+
|object T:
1376+
| extension (using Fo@@)
1377+
|""".stripMargin,
1378+
"""|Foo test
1379+
|""".stripMargin
1380+
)
1381+
1382+
1383+
@Test def `extension-definition-mix-1` =
1384+
check(
1385+
"""|trait Foo
1386+
|object T:
1387+
| extension (x: Int)(using Fo@@)
1388+
|""".stripMargin,
1389+
"""|Foo test
1390+
|""".stripMargin
1391+
)
1392+
1393+
@Test def `extension-definition-mix-2` =
1394+
check(
1395+
"""|trait Foo
1396+
|object T:
1397+
| extension (using Fo@@)(x: Int)(using Foo)
1398+
|""".stripMargin,
1399+
"""|Foo test
1400+
|""".stripMargin
1401+
)
1402+
1403+
@Test def `extension-definition-mix-3` =
1404+
check(
1405+
"""|trait Foo
1406+
|object T:
1407+
| extension (using Foo)(x: Int)(using Fo@@)
1408+
|""".stripMargin,
1409+
"""|Foo test
1410+
|""".stripMargin
1411+
)
1412+
1413+
@Test def `extension-definition-mix-4` =
1414+
check(
1415+
"""|trait Foo
1416+
|object T:
1417+
| extension [A](x: Fo@@)
1418+
|""".stripMargin,
1419+
"""|Foo test
1420+
|""".stripMargin
1421+
)
1422+
1423+
@Test def `extension-definition-mix-5` =
1424+
check(
1425+
"""|trait Foo
1426+
|object T:
1427+
| extension [A](using Fo@@)(x: Int)
1428+
|""".stripMargin,
1429+
"""|Foo test
1430+
|""".stripMargin
1431+
)
1432+
1433+
@Test def `extension-definition-mix-6` =
1434+
check(
1435+
"""|trait Foo
1436+
|object T:
1437+
| extension [A](using Foo)(x: Fo@@)
1438+
|""".stripMargin,
1439+
"""|Foo test
1440+
|""".stripMargin
1441+
)
1442+
1443+
@Test def `extension-definition-mix-7` =
1444+
check(
1445+
"""|trait Foo
1446+
|object T:
1447+
| extension [A](using Foo)(x: Fo@@)(using Fo@@)
13471448
|""".stripMargin,
13481449
"""|Foo test
13491450
|""".stripMargin

0 commit comments

Comments
 (0)