Skip to content

Commit 17ba942

Browse files
committed
Fix #1372: Add handler for PatDefs to REPL
1 parent 765aecb commit 17ba942

File tree

3 files changed

+73
-21
lines changed

3 files changed

+73
-21
lines changed

src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,11 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
447447
}
448448
}
449449

450+
/** Fold `f` over all tree nodes, in depth-first, prefix order */
451+
class UntypedDeepFolder[X](f: (X, Tree) => X) extends UntypedTreeAccumulator[X] {
452+
def apply(x: X, tree: Tree)(implicit ctx: Context): X = foldOver(f(x, tree), tree)
453+
}
454+
450455
override def rename(tree: NameTree, newName: Name)(implicit ctx: Context): tree.ThisTree[Untyped] = tree match {
451456
case t: PolyTypeDef =>
452457
cpy.PolyTypeDef(t)(newName.asTypeName, t.tparams, t.rhs).asInstanceOf[tree.ThisTree[Untyped]]

src/dotty/tools/dotc/repl/CompilingInterpreter.scala

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ class CompilingInterpreter(
222222
if (delayOutput)
223223
previousOutput ++= resultStrings.map(clean)
224224
else if (printResults || !succeeded)
225-
resultStrings.map(x => out.print(clean(x)))
225+
resultStrings.foreach(x => out.print(clean(x)))
226226
if (succeeded) {
227227
prevRequests += req
228228
Interpreter.Success
@@ -328,6 +328,7 @@ class CompilingInterpreter(
328328
private def chooseHandler(stat: Tree): StatementHandler = stat match {
329329
case stat: DefDef => new DefHandler(stat)
330330
case stat: ValDef => new ValHandler(stat)
331+
case stat: PatDef => new PatHandler(stat)
331332
case stat @ Assign(Ident(_), _) => new AssignHandler(stat)
332333
case stat: ModuleDef => new ModuleHandler(stat)
333334
case stat: TypeDef if stat.isClassDef => new ClassHandler(stat)
@@ -662,29 +663,65 @@ class CompilingInterpreter(
662663

663664
private class GenericHandler(statement: Tree) extends StatementHandler(statement)
664665

665-
private class ValHandler(statement: ValDef) extends StatementHandler(statement) {
666-
override val boundNames = List(statement.name)
666+
private abstract class ValOrPatHandler(statement: Tree)
667+
extends StatementHandler(statement) {
668+
override val boundNames: List[Name] = _boundNames
667669
override def valAndVarNames = boundNames
668670

669671
override def resultExtractionCode(req: Request, code: PrintWriter): Unit = {
670-
val vname = statement.name
671-
if (!statement.mods.is(Flags.AccessFlags) &&
672-
!(isGeneratedVarName(vname.toString) &&
673-
req.typeOf(vname.encode) == "Unit")) {
674-
val prettyName = vname.decode
675-
code.print(" + \"" + prettyName + ": " +
676-
string2code(req.typeOf(vname)) +
677-
" = \" + " +
678-
" (if(" +
679-
req.fullPath(vname) +
680-
".asInstanceOf[AnyRef] != null) " +
681-
" ((if(" +
682-
req.fullPath(vname) +
683-
".toString().contains('\\n')) " +
684-
" \"\\n\" else \"\") + " +
685-
req.fullPath(vname) + ".toString() + \"\\n\") else \"null\\n\") ")
686-
}
672+
if (!shouldShowResult(req)) return
673+
val resultExtractors = boundNames.map(name => resultExtractor(req, name))
674+
code.print(resultExtractors.mkString(""))
675+
}
676+
677+
private def resultExtractor(req: Request, varName: Name): String = {
678+
val prettyName = varName.decode
679+
val varType = string2code(req.typeOf(varName))
680+
val fullPath = req.fullPath(varName)
681+
682+
s""" + "$prettyName: $varType = " + {
683+
| if ($fullPath.asInstanceOf[AnyRef] != null) {
684+
| (if ($fullPath.toString().contains('\\n')) "\\n" else "") +
685+
| $fullPath.toString() + "\\n"
686+
| } else {
687+
| "null\\n"
688+
| }
689+
|}""".stripMargin
687690
}
691+
692+
protected def _boundNames: List[Name]
693+
protected def shouldShowResult(req: Request): Boolean
694+
}
695+
696+
private class ValHandler(statement: ValDef) extends ValOrPatHandler(statement) {
697+
override def _boundNames = List(statement.name)
698+
699+
override def shouldShowResult(req: Request): Boolean =
700+
!statement.mods.is(Flags.AccessFlags) &&
701+
!(isGeneratedVarName(statement.name.toString) &&
702+
req.typeOf(statement.name.encode) == "Unit")
703+
}
704+
705+
706+
private class PatHandler(statement: PatDef) extends ValOrPatHandler(statement) {
707+
override def _boundNames = statement.pats.flatMap(findVariableNames)
708+
709+
override def shouldShowResult(req: Request): Boolean =
710+
!statement.mods.is(Flags.AccessFlags)
711+
712+
private def findVariableNames(tree: Tree): List[Name] = tree match {
713+
case Ident(name) if name.toString != "_" => List(name)
714+
case _ => VariableNameFinder(Nil, tree).reverse
715+
}
716+
717+
private object VariableNameFinder extends UntypedDeepFolder[List[Name]](
718+
(acc: List[Name], t: Tree) => t match {
719+
case _: BackquotedIdent => acc
720+
case Ident(name) if name.isVariableName && name.toString != "_" => name :: acc
721+
case Bind(name, _) if name.isVariableName => name :: acc
722+
case _ => acc
723+
}
724+
)
688725
}
689726

690727
private class DefHandler(defDef: DefDef) extends StatementHandler(defDef) {
@@ -836,7 +873,7 @@ class CompilingInterpreter(
836873
val stringWriter = new StringWriter()
837874
val stream = new NewLinePrintWriter(stringWriter)
838875
writer(stream)
839-
stream.close
876+
stream.close()
840877
stringWriter.toString
841878
}
842879

tests/repl/patdef.check

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
scala> val Const,x = 0
2+
Const: Int = 0
3+
x: Int = 0
4+
scala> val (Const, List(`x`, _, a), b) = (0, List(0, 1337, 1), 2)
5+
a: Int = 1
6+
b: Int = 2
7+
scala> val a@b = 0
8+
a: Int @unchecked = 0
9+
b: Int @unchecked = 0
10+
scala> :quit

0 commit comments

Comments
 (0)