Skip to content

Commit 4dee036

Browse files
committed
Add support for await inside try-catch
- `await` can now be used inside the body of `try` and `finally` - using `await` inside the cases of a `catch` is illegal - provides precise error messages ("await must not be used under catch") - adds 9 tests
1 parent ffd7b96 commit 4dee036

File tree

8 files changed

+385
-87
lines changed

8 files changed

+385
-87
lines changed

src/main/scala/scala/async/AnfTransform.scala

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -121,44 +121,42 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
121121

122122
private object inline {
123123
def transformToList(tree: Tree): List[Tree] = trace("inline", tree) {
124+
def branchWithAssign(orig: Tree, varDef: ValDef) = orig match {
125+
case Block(stats, expr) => Block(stats, Assign(Ident(varDef.name), expr))
126+
case _ => Assign(Ident(varDef.name), orig)
127+
}
128+
129+
def casesWithAssign(cases: List[CaseDef], varDef: ValDef) = cases map {
130+
case cd @ CaseDef(pat, guard, orig) =>
131+
attachCopy(cd)(CaseDef(pat, guard, branchWithAssign(orig, varDef)))
132+
}
133+
124134
val stats :+ expr = anf.transformToList(tree)
125135
expr match {
136+
// if type of if-else/try/match is Unit don't introduce assignment,
137+
// but add Unit value to bring it into form expected by async transform
138+
case If(_, _, _) | Try(_, _, _) | Match(_, _) if expr.tpe =:= definitions.UnitTpe =>
139+
stats :+ expr :+ Literal(Constant(()))
140+
126141
case Apply(fun, args) if isAwait(fun) =>
127142
val valDef = defineVal(name.await, expr, tree.pos)
128143
stats :+ valDef :+ Ident(valDef.name)
129144

130145
case If(cond, thenp, elsep) =>
131-
// if type of if-else is Unit don't introduce assignment,
132-
// but add Unit value to bring it into form expected by async transform
133-
if (expr.tpe =:= definitions.UnitTpe) {
134-
stats :+ expr :+ Literal(Constant(()))
135-
} else {
136-
val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
137-
def branchWithAssign(orig: Tree) = orig match {
138-
case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr))
139-
case _ => Assign(Ident(varDef.name), orig)
140-
}
141-
val ifWithAssign = If(cond, branchWithAssign(thenp), branchWithAssign(elsep))
142-
stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name)
143-
}
146+
val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
147+
val ifWithAssign = If(cond, branchWithAssign(thenp, varDef), branchWithAssign(elsep, varDef))
148+
stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name)
149+
150+
case Try(body, catches, finalizer) =>
151+
val varDef = defineVar(name.tryRes, expr.tpe, tree.pos)
152+
val tryWithAssign = Try(branchWithAssign(body, varDef), casesWithAssign(catches, varDef), finalizer)
153+
stats :+ varDef :+ tryWithAssign :+ Ident(varDef.name)
144154

145155
case Match(scrut, cases) =>
146-
// if type of match is Unit don't introduce assignment,
147-
// but add Unit value to bring it into form expected by async transform
148-
if (expr.tpe =:= definitions.UnitTpe) {
149-
stats :+ expr :+ Literal(Constant(()))
150-
}
151-
else {
152-
val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
153-
val casesWithAssign = cases map {
154-
case cd@CaseDef(pat, guard, Block(caseStats, caseExpr)) =>
155-
attachCopy(cd)(CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr))))
156-
case cd@CaseDef(pat, guard, body) =>
157-
attachCopy(cd)(CaseDef(pat, guard, Assign(Ident(varDef.name), body)))
158-
}
159-
val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign))
160-
stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name)
161-
}
156+
val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
157+
val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign(cases, varDef)))
158+
stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name)
159+
162160
case _ =>
163161
stats :+ expr
164162
}
@@ -225,6 +223,11 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
225223
val stats :+ expr = inline.transformToList(rhs)
226224
stats :+ attachCopy(tree)(Assign(lhs, expr))
227225

226+
case Try(body, catches, finalizer) if containsAwait =>
227+
val stats :+ expr = inline.transformToList(body)
228+
val tryType = c.typeCheck(Try(Block(stats, expr), catches, finalizer)).tpe
229+
List(attachCopy(tree)(Try(Block(stats, expr), catches, finalizer)).setType(tryType))
230+
228231
case If(cond, thenp, elsep) if containsAwait =>
229232
val condStats :+ condExpr = inline.transformToList(cond)
230233
val thenBlock = inline.transformToBlock(thenp)

src/main/scala/scala/async/Async.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ abstract class AsyncBase {
120120
val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0)))
121121
val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree)
122122
val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree)
123+
124+
// the stack of currently active exception handlers
125+
val handlers = ValDef(Modifiers(Flag.MUTABLE), name.handlers, TypeTree(typeOf[List[PartialFunction[Throwable, Unit]]]), (reify { List() }).tree)
126+
127+
// the exception that is currently in-flight or `null` otherwise
128+
val exception = ValDef(Modifiers(Flag.MUTABLE), name.exception, TypeTree(typeOf[Throwable]), Literal(Constant(null)))
129+
123130
val applyDefDef: DefDef = {
124131
val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)))
125132
val applyBody = asyncBlock.onCompleteHandler
@@ -132,7 +139,7 @@ abstract class AsyncBase {
132139
val applyBody = asyncBlock.onCompleteHandler
133140
DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil))
134141
}
135-
List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef)
142+
List(utils.emptyConstructor, stateVar, result, execContext, handlers, exception) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef)
136143
}
137144
val template = {
138145
Template(List(stateMachineType), emptyValDef, body)

src/main/scala/scala/async/AsyncAnalysis.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,16 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: Asy
7676
}
7777

7878
override def traverse(tree: Tree) {
79-
def containsAwait = tree exists isAwait
79+
def containsAwait(t: Tree) = t exists isAwait
8080
tree match {
81-
case Try(_, _, _) if containsAwait =>
82-
reportUnsupportedAwait(tree, "try/catch")
81+
case Try(_, catches, _) if catches exists containsAwait =>
82+
reportUnsupportedAwait(tree, "catch")
8383
super.traverse(tree)
84-
case Return(_) =>
84+
case Return(_) =>
8585
c.abort(tree.pos, "return is illegal within a async block")
86-
case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) =>
86+
case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) =>
8787
c.abort(tree.pos, "lazy vals are illegal within an async block")
88-
case _ =>
88+
case _ =>
8989
super.traverse(tree)
9090
}
9191
}

0 commit comments

Comments
 (0)