Skip to content

Commit 2a17a7c

Browse files
committed
Fix #1372: Add handler for PatDefs to REPL
1 parent 765aecb commit 2a17a7c

File tree

3 files changed

+79
-21
lines changed

3 files changed

+79
-21
lines changed

src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,11 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
447447
}
448448
}
449449

450+
/** Fold `f` over all tree nodes, in depth-first, prefix order */
451+
class UntypedDeepFolder[X](f: (X, Tree) => X) extends UntypedTreeAccumulator[X] {
452+
def apply(x: X, tree: Tree)(implicit ctx: Context): X = foldOver(f(x, tree), tree)
453+
}
454+
450455
override def rename(tree: NameTree, newName: Name)(implicit ctx: Context): tree.ThisTree[Untyped] = tree match {
451456
case t: PolyTypeDef =>
452457
cpy.PolyTypeDef(t)(newName.asTypeName, t.tparams, t.rhs).asInstanceOf[tree.ThisTree[Untyped]]

src/dotty/tools/dotc/repl/CompilingInterpreter.scala

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ class CompilingInterpreter(
222222
if (delayOutput)
223223
previousOutput ++= resultStrings.map(clean)
224224
else if (printResults || !succeeded)
225-
resultStrings.map(x => out.print(clean(x)))
225+
resultStrings.foreach(x => out.print(clean(x)))
226226
if (succeeded) {
227227
prevRequests += req
228228
Interpreter.Success
@@ -328,6 +328,7 @@ class CompilingInterpreter(
328328
private def chooseHandler(stat: Tree): StatementHandler = stat match {
329329
case stat: DefDef => new DefHandler(stat)
330330
case stat: ValDef => new ValHandler(stat)
331+
case stat: PatDef => new PatHandler(stat)
331332
case stat @ Assign(Ident(_), _) => new AssignHandler(stat)
332333
case stat: ModuleDef => new ModuleHandler(stat)
333334
case stat: TypeDef if stat.isClassDef => new ClassHandler(stat)
@@ -662,29 +663,71 @@ class CompilingInterpreter(
662663

663664
private class GenericHandler(statement: Tree) extends StatementHandler(statement)
664665

665-
private class ValHandler(statement: ValDef) extends StatementHandler(statement) {
666-
override val boundNames = List(statement.name)
666+
private abstract class ValOrPatHandler(statement: Tree)
667+
extends StatementHandler(statement) {
668+
override val boundNames: List[Name] = _boundNames
667669
override def valAndVarNames = boundNames
668670

669671
override def resultExtractionCode(req: Request, code: PrintWriter): Unit = {
670-
val vname = statement.name
671-
if (!statement.mods.is(Flags.AccessFlags) &&
672-
!(isGeneratedVarName(vname.toString) &&
673-
req.typeOf(vname.encode) == "Unit")) {
674-
val prettyName = vname.decode
675-
code.print(" + \"" + prettyName + ": " +
676-
string2code(req.typeOf(vname)) +
677-
" = \" + " +
678-
" (if(" +
679-
req.fullPath(vname) +
680-
".asInstanceOf[AnyRef] != null) " +
681-
" ((if(" +
682-
req.fullPath(vname) +
683-
".toString().contains('\\n')) " +
684-
" \"\\n\" else \"\") + " +
685-
req.fullPath(vname) + ".toString() + \"\\n\") else \"null\\n\") ")
686-
}
672+
if (!shouldShowResult(req)) return
673+
674+
val resultExtractors =
675+
for (varName <- boundNames)
676+
yield resultExtractionCode(req, varName)
677+
code.print(resultExtractors.mkString(""))
678+
}
679+
680+
private def resultExtractionCode(req: Request, varName: Name): String = {
681+
def if_(condition: String)(thenBranch: String)(elseBranch: String): String =
682+
s"(if ($condition) {$thenBranch} else {$elseBranch})"
683+
684+
val prettyName = varName.decode
685+
val varType = string2code(req.typeOf(varName))
686+
val fullPath = req.fullPath(varName)
687+
688+
s""" + "$prettyName: $varType = " + """ +
689+
if_(s"$fullPath.asInstanceOf[AnyRef] != null") {
690+
if_(s"$fullPath.toString().contains('\\n')") { "\"\\n\"" }
691+
/* else */ { "\"\"" } +
692+
s""" + $fullPath.toString() + "\\n" """
693+
} /* else */ {
694+
"\"null\\n\""
695+
}
687696
}
697+
698+
protected def _boundNames: List[Name]
699+
protected def shouldShowResult(req: Request): Boolean
700+
}
701+
702+
private class ValHandler(statement: ValDef) extends ValOrPatHandler(statement) {
703+
override def _boundNames = List(statement.name)
704+
705+
override def shouldShowResult(req: Request): Boolean =
706+
!statement.mods.is(Flags.AccessFlags) &&
707+
!(isGeneratedVarName(statement.name.toString) &&
708+
req.typeOf(statement.name.encode) == "Unit")
709+
}
710+
711+
712+
private class PatHandler(statement: PatDef) extends ValOrPatHandler(statement) {
713+
override def _boundNames = statement.pats.flatMap(findVariableNames)
714+
715+
override def shouldShowResult(req: Request): Boolean =
716+
!statement.mods.is(Flags.AccessFlags)
717+
718+
private def findVariableNames(tree: Tree): List[Name] = tree match {
719+
case Ident(name) if name.toString != "_" => List(name)
720+
case _ => VariableNameFinder(Nil, tree).reverse
721+
}
722+
723+
private object VariableNameFinder extends UntypedDeepFolder[List[Name]](
724+
(acc: List[Name], t: Tree) => t match {
725+
case _: BackquotedIdent => acc
726+
case Ident(name) if name.isVariableName && name.toString != "_" => name :: acc
727+
case Bind(name, _) if name.isVariableName => name :: acc
728+
case _ => acc
729+
}
730+
)
688731
}
689732

690733
private class DefHandler(defDef: DefDef) extends StatementHandler(defDef) {
@@ -836,7 +879,7 @@ class CompilingInterpreter(
836879
val stringWriter = new StringWriter()
837880
val stream = new NewLinePrintWriter(stringWriter)
838881
writer(stream)
839-
stream.close
882+
stream.close()
840883
stringWriter.toString
841884
}
842885

tests/repl/patdef.check

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
scala> val Const,x = 0
2+
Const: Int = 0
3+
x: Int = 0
4+
scala> val (Const, List(`x`, _, a), b) = (0, List(0, 1337, 1), 2)
5+
a: Int = 1
6+
b: Int = 2
7+
scala> val a@b = 0
8+
a: Int @unchecked = 0
9+
b: Int @unchecked = 0
10+
scala> :quit

0 commit comments

Comments
 (0)