Skip to content

Commit d64894b

Browse files
committed
Set untpdTree in repl compilation unit for completions
1 parent ae1b409 commit d64894b

File tree

5 files changed

+138
-20
lines changed

5 files changed

+138
-20
lines changed

compiler/src/dotty/tools/dotc/interactive/Completion.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,9 @@ object Completion {
136136
*/
137137
def pathBeforeDesugaring(path: List[Tree], pos: SourcePosition)(using Context): List[Tree] =
138138
val hasUntypedTree = path.headOption.forall(NavigateAST.untypedPath(_, exactMatch = true).nonEmpty)
139-
if hasUntypedTree then
140-
path
141-
else
142-
NavigateAST.untypedPath(pos.span).collect:
143-
case tree: untpd.Tree => tree
139+
if hasUntypedTree then path
140+
else NavigateAST.untypedPath(pos.span).collect:
141+
case tree: untpd.Tree => tree
144142

145143
private def computeCompletions(pos: SourcePosition, path: List[Tree])(using Context): (Int, List[Completion]) = {
146144
val path0 = pathBeforeDesugaring(path, pos)

compiler/src/dotty/tools/repl/ReplCompiler.scala

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ class ReplCompiler extends Compiler:
9393
end compile
9494

9595
final def typeOf(expr: String)(using state: State): Result[String] =
96-
typeCheck(expr).map { tree =>
96+
typeCheck(expr).map { (_, tpdTree) =>
9797
given Context = state.context
98-
tree.rhs match {
98+
tpdTree.rhs match {
9999
case Block(xs, _) => xs.last.tpe.widen.show
100100
case _ =>
101101
"""Couldn't compute the type of your expression, so sorry :(
@@ -129,7 +129,7 @@ class ReplCompiler extends Compiler:
129129
Iterator(sym) ++ sym.allOverriddenSymbols
130130
}
131131

132-
typeCheck(expr).map {
132+
typeCheck(expr).map { (_, tpdTree) => tpdTree match
133133
case ValDef(_, _, Block(stats, _)) if stats.nonEmpty =>
134134
val stat = stats.last.asInstanceOf[tpd.Tree]
135135
if (stat.tpe.isError) stat.tpe.show
@@ -152,7 +152,7 @@ class ReplCompiler extends Compiler:
152152
}
153153
}
154154

155-
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[tpd.ValDef] = {
155+
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[(untpd.ValDef, tpd.ValDef)] = {
156156

157157
def wrapped(expr: String, sourceFile: SourceFile, state: State)(using Context): Result[untpd.PackageDef] = {
158158
def wrap(trees: List[untpd.Tree]): untpd.PackageDef = {
@@ -181,22 +181,32 @@ class ReplCompiler extends Compiler:
181181
}
182182
}
183183

184-
def unwrapped(tree: tpd.Tree, sourceFile: SourceFile)(using Context): Result[tpd.ValDef] = {
185-
def error: Result[tpd.ValDef] =
186-
List(new Diagnostic.Error(s"Invalid scala expression",
187-
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors
184+
def error[Tree <: untpd.Tree](sourceFile: SourceFile): Result[Tree] =
185+
List(new Diagnostic.Error(s"Invalid scala expression",
186+
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors
188187

188+
def unwrappedTypeTree(tree: tpd.Tree, sourceFile0: SourceFile)(using Context): Result[tpd.ValDef] = {
189189
import tpd._
190190
tree match {
191191
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
192192
tmpl.body
193193
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
194-
.getOrElse(error)
194+
.getOrElse(error[tpd.ValDef](sourceFile0))
195195
case _ =>
196-
error
196+
error[tpd.ValDef](sourceFile0)
197197
}
198198
}
199199

200+
def unwrappedUntypedTree(tree: untpd.Tree, sourceFile0: SourceFile)(using Context): Result[untpd.ValDef] =
201+
import untpd._
202+
tree match {
203+
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
204+
tmpl.body
205+
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
206+
.getOrElse(error[untpd.ValDef](sourceFile0))
207+
case _ =>
208+
error[untpd.ValDef](sourceFile0)
209+
}
200210

201211
val src = SourceFile.virtual("<typecheck>", expr)
202212
inContext(state.context.fresh
@@ -209,7 +219,10 @@ class ReplCompiler extends Compiler:
209219
ctx.run.nn.compileUnits(unit :: Nil, ctx)
210220

211221
if (errorsAllowed || !ctx.reporter.hasErrors)
212-
unwrapped(unit.tpdTree, src)
222+
for
223+
tpdTree <- unwrappedTypeTree(unit.tpdTree, src)
224+
untpdTree <- unwrappedUntypedTree(unit.untpdTree, src)
225+
yield untpdTree -> tpdTree
213226
else
214227
ctx.reporter.removeBufferedMessages.errors
215228
}

compiler/src/dotty/tools/repl/ReplDriver.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,11 @@ class ReplDriver(settings: Array[String],
251251
given state: State = newRun(state0)
252252
compiler
253253
.typeCheck(expr, errorsAllowed = true)
254-
.map { tree =>
254+
.map { (untpdTree, tpdTree) =>
255255
val file = SourceFile.virtual("<completions>", expr, maybeIncomplete = true)
256256
val unit = CompilationUnit(file)(using state.context)
257-
unit.tpdTree = tree
257+
unit.untpdTree = untpdTree
258+
unit.tpdTree = tpdTree
258259
given Context = state.context.fresh.setCompilationUnit(unit)
259260
val srcPos = SourcePosition(file, Span(cursor))
260261
val completions = try Completion.completions(srcPos)._2 catch case NonFatal(_) => Nil

compiler/test/dotty/tools/repl/TabcompleteTests.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ class TabcompleteTests extends ReplTest {
3232
assertEquals(List("apply"), comp)
3333
}
3434

35+
@Test def tabCompleteInExtensionDefinition = initially {
36+
val comp = tabComplete("extension (x: Lis")
37+
assertEquals(List("List"), comp)
38+
}
39+
3540
@Test def tabCompleteTwiceIn = {
3641
val src1 = "class Foo { def bar(xs: List[Int]) = xs.map"
3742
val src2 = "class Foo { def bar(xs: List[Int]) = xs.mapC"

presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionSuite.scala

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,16 @@ class CompletionSuite extends BaseCompletionSuite:
12531253
)
12541254

12551255
@Test def `extension-definition-scope` =
1256+
check(
1257+
"""|trait Foo
1258+
|object T:
1259+
| extension (x: Fo@@)
1260+
|""".stripMargin,
1261+
"""|Foo test
1262+
|""".stripMargin
1263+
)
1264+
1265+
@Test def `extension-definition-symbol-search` =
12561266
check(
12571267
"""|object T:
12581268
| extension (x: ListBuffe@@)
@@ -1262,11 +1272,102 @@ class CompletionSuite extends BaseCompletionSuite:
12621272
|""".stripMargin,
12631273
)
12641274

1265-
@Test def `extension-definition-symbol-search` =
1275+
@Test def `extension-definition-type-parameter` =
12661276
check(
12671277
"""|trait Foo
12681278
|object T:
1269-
| extension (x: Fo@@)
1279+
| extension [A <: Fo@@]
1280+
|""".stripMargin,
1281+
"""|Foo test
1282+
|""".stripMargin
1283+
)
1284+
1285+
@Test def `extension-definition-type-parameter-symbol-search` =
1286+
check(
1287+
"""|object T:
1288+
| extension [A <: ListBuffe@@]
1289+
|""".stripMargin,
1290+
"""|ListBuffer[T] - scala.collection.mutable
1291+
|ListBuffer - scala.collection.mutable
1292+
|""".stripMargin
1293+
)
1294+
1295+
@Test def `extension-definition-using-param-clause` =
1296+
check(
1297+
"""|trait Foo
1298+
|object T:
1299+
| extension (using Fo@@)
1300+
|""".stripMargin,
1301+
"""|Foo test
1302+
|""".stripMargin
1303+
)
1304+
1305+
1306+
@Test def `extension-definition-mix-1` =
1307+
check(
1308+
"""|trait Foo
1309+
|object T:
1310+
| extension (x: Int)(using Fo@@)
1311+
|""".stripMargin,
1312+
"""|Foo test
1313+
|""".stripMargin
1314+
)
1315+
1316+
@Test def `extension-definition-mix-2` =
1317+
check(
1318+
"""|trait Foo
1319+
|object T:
1320+
| extension (using Fo@@)(x: Int)(using Foo)
1321+
|""".stripMargin,
1322+
"""|Foo test
1323+
|""".stripMargin
1324+
)
1325+
1326+
@Test def `extension-definition-mix-3` =
1327+
check(
1328+
"""|trait Foo
1329+
|object T:
1330+
| extension (using Foo)(x: Int)(using Fo@@)
1331+
|""".stripMargin,
1332+
"""|Foo test
1333+
|""".stripMargin
1334+
)
1335+
1336+
@Test def `extension-definition-mix-4` =
1337+
check(
1338+
"""|trait Foo
1339+
|object T:
1340+
| extension [A](x: Fo@@)
1341+
|""".stripMargin,
1342+
"""|Foo test
1343+
|""".stripMargin
1344+
)
1345+
1346+
@Test def `extension-definition-mix-5` =
1347+
check(
1348+
"""|trait Foo
1349+
|object T:
1350+
| extension [A](using Fo@@)(x: Int)
1351+
|""".stripMargin,
1352+
"""|Foo test
1353+
|""".stripMargin
1354+
)
1355+
1356+
@Test def `extension-definition-mix-6` =
1357+
check(
1358+
"""|trait Foo
1359+
|object T:
1360+
| extension [A](using Foo)(x: Fo@@)
1361+
|""".stripMargin,
1362+
"""|Foo test
1363+
|""".stripMargin
1364+
)
1365+
1366+
@Test def `extension-definition-mix-7` =
1367+
check(
1368+
"""|trait Foo
1369+
|object T:
1370+
| extension [A](using Foo)(x: Fo@@)(using Fo@@)
12701371
|""".stripMargin,
12711372
"""|Foo test
12721373
|""".stripMargin

0 commit comments

Comments
 (0)