Skip to content

Commit ac226f2

Browse files
committed
Implement non-local returns
Non-local returns are now implemented.
1 parent 07e24e8 commit ac226f2

File tree

5 files changed

+126
-0
lines changed

5 files changed

+126
-0
lines changed

src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class Compiler {
6868
new LazyVals,
6969
new Memoize,
7070
new LinkScala2ImplClasses,
71+
new NonLocalReturns,
7172
new CapturedVars, // capturedVars has a transformUnit: no phases should introduce local mutable vars here
7273
new Constructors,
7374
new FunctionalInterfaces,

src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
161161
def Bind(sym: TermSymbol, body: Tree)(implicit ctx: Context): Bind =
162162
ta.assignType(untpd.Bind(sym.name, body), sym)
163163

164+
/** A pattern corrsponding to `sym: tpe` */
165+
def BindTyped(sym: TermSymbol, tpe: Type)(implicit ctx: Context): Bind =
166+
Bind(sym, Typed(Underscore(tpe), TypeTree(tpe)))
167+
164168
def Alternative(trees: List[Tree])(implicit ctx: Context): Alternative =
165169
ta.assignType(untpd.Alternative(trees), trees)
166170

src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ class Definitions {
326326
lazy val Product_productArity = ProductClass.requiredMethod(nme.productArity)
327327
lazy val Product_productPrefix = ProductClass.requiredMethod(nme.productPrefix)
328328
lazy val LanguageModuleClass = ctx.requiredModule("dotty.language").moduleClass.asClass
329+
lazy val NonLocalReturnControlClass = ctx.requiredClass("scala.runtime.NonLocalReturnControl")
329330

330331
// Annotation base classes
331332
lazy val AnnotationClass = ctx.requiredClass("scala.annotation.Annotation")
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import core._
5+
import Contexts._, Symbols._, Types._, Flags._, Decorators._, StdNames._, Constants._, Phases._
6+
import TreeTransforms._
7+
import ast.Trees._
8+
import collection.mutable
9+
10+
/** Implement non-local returns using NonLocalReturnControl exceptions.
11+
*/
12+
class NonLocalReturns extends MiniPhaseTransform { thisTransformer =>
13+
override def phaseName = "nonLocalReturns"
14+
15+
import ast.tpd._
16+
17+
override def runsAfter: Set[Class[_ <: Phase]] = Set(classOf[ElimByName])
18+
19+
private def ensureConforms(tree: Tree, pt: Type)(implicit ctx: Context) =
20+
if (tree.tpe <:< pt) tree
21+
else Erasure.Boxing.adaptToType(tree, pt)
22+
23+
/** The type of a non-local return expression with given argument type */
24+
private def nonLocalReturnExceptionType(argtype: Type)(implicit ctx: Context) =
25+
defn.NonLocalReturnControlClass.typeRef.appliedTo(argtype)
26+
27+
/** A hashmap from method symbols to non-local return keys */
28+
private val nonLocalReturnKeys = mutable.Map[Symbol, TermSymbol]()
29+
30+
/** Return non-local return key for given method */
31+
private def nonLocalReturnKey(meth: Symbol)(implicit ctx: Context) =
32+
nonLocalReturnKeys.getOrElseUpdate(meth,
33+
ctx.newSymbol(
34+
meth, ctx.freshName("nonLocalReturnKey").toTermName, Synthetic, defn.ObjectType, coord = meth.pos))
35+
36+
/** Generate a non-local return throw with given return expression from given method.
37+
* I.e. for the method's non-local return key, generate:
38+
*
39+
* throw new NonLocalReturnControl(key, expr)
40+
* todo: maybe clone a pre-existing exception instead?
41+
* (but what to do about exceptions that miss their targets?)
42+
*/
43+
private def nonLocalReturnThrow(expr: Tree, meth: Symbol)(implicit ctx: Context) =
44+
Throw(
45+
New(
46+
defn.NonLocalReturnControlClass.typeRef,
47+
ref(nonLocalReturnKey(meth)) :: ensureConforms(expr, defn.ObjectType) :: Nil))
48+
49+
/** Transform (body, key) to:
50+
*
51+
* {
52+
* val key = new Object()
53+
* try {
54+
* body
55+
* } catch {
56+
* case ex: NonLocalReturnControl =>
57+
* if (ex.key().eq(key)) ex.value().asInstanceOf[T]
58+
* else throw ex
59+
* }
60+
* }
61+
*/
62+
private def nonLocalReturnTry(body: Tree, key: TermSymbol, meth: Symbol)(implicit ctx: Context) = {
63+
val keyDef = ValDef(key, New(defn.ObjectType, Nil))
64+
val nonLocalReturnControl = defn.NonLocalReturnControlClass.typeRef
65+
val ex = ctx.newSymbol(meth, nme.ex, EmptyFlags, nonLocalReturnControl, coord = body.pos)
66+
val pat = BindTyped(ex, nonLocalReturnControl)
67+
val rhs = If(
68+
ref(ex).select(nme.key).appliedToNone.select(nme.eq).appliedTo(ref(key)),
69+
ensureConforms(ref(ex).select(nme.value), meth.info.finalResultType),
70+
Throw(ref(ex)))
71+
val catches = CaseDef(pat, EmptyTree, rhs) :: Nil
72+
val tryCatch = Try(body, catches, EmptyTree)
73+
Block(keyDef :: Nil, tryCatch)
74+
}
75+
76+
def isNonLocalReturn(ret: Return)(implicit ctx: Context) =
77+
ret.from.symbol != ctx.owner.enclosingMethod || ctx.owner.is(Lazy) // Lazy needed?
78+
79+
override def transformDefDef(tree: DefDef)(implicit ctx: Context, info: TransformerInfo): Tree =
80+
nonLocalReturnKeys.remove(tree.symbol) match {
81+
case Some(key) => cpy.DefDef(tree)(rhs = nonLocalReturnTry(tree.rhs, key, tree.symbol))
82+
case _ => tree
83+
}
84+
85+
override def transformReturn(tree: Return)(implicit ctx: Context, info: TransformerInfo): Tree =
86+
if (isNonLocalReturn(tree)) nonLocalReturnThrow(tree.expr, tree.from.symbol).withPos(tree.pos)
87+
else tree
88+
}

tests/run/nonLocalReturns.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
object Test {
2+
3+
def foo(xs: List[Int]): Int = {
4+
xs.foreach(x => return x)
5+
0
6+
}
7+
8+
def bar(xs: List[Int]): Int = {
9+
lazy val y = if (xs.isEmpty) return -1 else xs.head
10+
y
11+
}
12+
13+
def baz(x: Int): Int =
14+
byName { return -2; 3 }
15+
16+
def byName(x: => Int): Int = x
17+
18+
def bam(): Int = { // no non-local return needed here
19+
val foo = {
20+
return -3
21+
3
22+
}
23+
foo
24+
}
25+
26+
def main(args: Array[String]) = {
27+
assert(foo(List(1, 2, 3)) == 1)
28+
assert(bar(Nil) == -1)
29+
assert(baz(3) == -2)
30+
assert(bam() == -3)
31+
}
32+
}

0 commit comments

Comments
 (0)