Skip to content

Commit 1333b38

Browse files
committed
Avoid unbounded stack consumption for synchronous control flow
Previously, as sequence of state transitions that did not pass through an asynchrous boundary incurred stack frames. The trivial loop in the enclosed test case would then overflow the stack. This commit merges the `resume` and `apply(tr: Try[Any])` methods into a `apply`. It changes the body of this method to be an infinite loop with returns at the terminal points in the state machine (or at a terminal failure.) To allow merging of these previously separate matches, states that contain an await are now allocated two state ids: one for the setup code that calls `onComplete`, and one for the code in the continuation that records the result and advances the state machine. Fixes #93
1 parent 61b4c18 commit 1333b38

File tree

7 files changed

+87
-54
lines changed

7 files changed

+87
-54
lines changed

src/main/scala/scala/async/internal/AsyncTransform.scala

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,31 +24,29 @@ trait AsyncTransform {
2424

2525
val anfTree = futureSystemOps.postAnfTransform(anfTree0)
2626

27-
val resumeFunTreeDummyBody = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Literal(Constant(())))
28-
2927
val applyDefDefDummyBody: DefDef = {
3028
val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(futureSystemOps.tryType[Any]), EmptyTree)))
31-
DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), Literal(Constant(())))
29+
DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), literalUnit)
3230
}
3331

3432
// Create `ClassDef` of state machine with empty method bodies for `resume` and `apply`.
3533
val stateMachine: ClassDef = {
3634
val body: List[Tree] = {
37-
val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0)))
35+
val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(StateAssigner.Initial)))
3836
val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T](uncheckedBoundsResultTag)), futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree)
3937
val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext)
4038

4139
val apply0DefDef: DefDef = {
4240
// We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
4341
// See SI-1247 for the the optimization that avoids creatio
44-
DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil))
42+
DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.apply), literalNull :: Nil))
4543
}
4644
val extraValDef: ValDef = {
4745
// We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
4846
// See SI-1247 for the the optimization that avoids creatio
49-
ValDef(NoMods, newTermName("extra"), TypeTree(definitions.UnitTpe), Literal(Constant(())))
47+
ValDef(NoMods, newTermName("extra"), TypeTree(definitions.UnitTpe), literalUnit)
5048
}
51-
List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef, extraValDef)
49+
List(emptyConstructor, stateVar, result, execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef, extraValDef)
5250
}
5351

5452
val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit])
@@ -90,8 +88,7 @@ trait AsyncTransform {
9088
val stateMachineSpliced: Tree = spliceMethodBodies(
9189
liftedFields,
9290
stateMachine,
93-
atMacroPos(asyncBlock.onCompleteHandler[T]),
94-
atMacroPos(asyncBlock.resumeFunTree[T].rhs)
91+
atMacroPos(asyncBlock.onCompleteHandler[T])
9592
)
9693

9794
def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection)
@@ -131,10 +128,9 @@ trait AsyncTransform {
131128
* @param liftables trees of definitions that are lifted to fields of the state machine class
132129
* @param tree `ClassDef` tree of the state machine class
133130
* @param applyBody tree of onComplete handler (`apply` method)
134-
* @param resumeBody RHS of definition tree of `resume` method
135131
* @return transformed `ClassDef` tree of the state machine class
136132
*/
137-
def spliceMethodBodies(liftables: List[Tree], tree: ClassDef, applyBody: Tree, resumeBody: Tree): Tree = {
133+
def spliceMethodBodies(liftables: List[Tree], tree: ClassDef, applyBody: Tree): Tree = {
138134
val liftedSyms = liftables.map(_.symbol).toSet
139135
val stateMachineClass = tree.symbol
140136
liftedSyms.foreach {
@@ -211,12 +207,6 @@ trait AsyncTransform {
211207
(ctx: analyzer.Context) =>
212208
val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx)
213209
typedTree
214-
215-
case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass =>
216-
(ctx: analyzer.Context) =>
217-
val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol)
218-
val res = fixup(dd, changed, ctx)
219-
res
220210
}
221211
result
222212
}

src/main/scala/scala/async/internal/ExprBuilder.scala

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ trait ExprBuilder {
5252
List(nextState)
5353

5454
def mkHandlerCaseForState: CaseDef =
55-
mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup))
55+
mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup))
5656

5757
override val toString: String =
5858
s"AsyncState #$state, next = $nextState"
@@ -72,7 +72,7 @@ trait ExprBuilder {
7272
/** A sequence of statements that concludes with an `await` call. The `onComplete`
7373
* handler will unconditionally transition to `nextState`.
7474
*/
75-
final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, nextState: Int,
75+
final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, onCompleteState: Int, nextState: Int,
7676
val awaitable: Awaitable, symLookup: SymLookup)
7777
extends AsyncState {
7878

@@ -82,7 +82,7 @@ trait ExprBuilder {
8282
override def mkHandlerCaseForState: CaseDef = {
8383
val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr),
8484
Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree
85-
mkHandlerCase(state, stats :+ callOnComplete)
85+
mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup), callOnComplete, Return(literalUnit)))
8686
}
8787

8888
override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = {
@@ -102,15 +102,16 @@ trait ExprBuilder {
102102
*/
103103
val ifIsFailureTree =
104104
If(futureSystemOps.tryyIsFailure(Expr[futureSystem.Tryy[T]](Ident(symLookup.applyTrParam))).tree,
105-
futureSystemOps.completeProm[T](
105+
Block(futureSystemOps.completeProm[T](
106106
Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)),
107107
Expr[futureSystem.Tryy[T]](
108108
TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")),
109-
List(TypeTree(futureSystemOps.tryType[T]))))).tree,
110-
Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup))
109+
List(TypeTree(futureSystemOps.tryType[T]))))).tree :: Nil,
110+
Return(literalUnit)),
111+
Block(List(tryGetTree), mkStateTree(nextState, symLookup))
111112
)
112113

113-
Some(mkHandlerCase(state, List(ifIsFailureTree)))
114+
Some(mkHandlerCase(onCompleteState, List(ifIsFailureTree)))
114115
}
115116

116117
override val toString: String =
@@ -146,9 +147,10 @@ trait ExprBuilder {
146147
}
147148

148149
def resultWithAwait(awaitable: Awaitable,
150+
onCompleteState: Int,
149151
nextState: Int): AsyncState = {
150152
val effectiveNextState = nextJumpState.getOrElse(nextState)
151-
new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup)
153+
new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState, awaitable, symLookup)
152154
}
153155

154156
def resultSimple(nextState: Int): AsyncState = {
@@ -157,7 +159,7 @@ trait ExprBuilder {
157159
}
158160

159161
def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = {
160-
def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup))
162+
def mkBranch(state: Int) = mkStateTree(state, symLookup)
161163
this += If(condTree, mkBranch(thenState), mkBranch(elseState))
162164
new AsyncStateWithoutAwait(stats.toList, state, List(thenState, elseState))
163165
}
@@ -177,15 +179,15 @@ trait ExprBuilder {
177179
val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
178180
case CaseDef(pat, guard, rhs) =>
179181
val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal)
180-
CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup)))
182+
CaseDef(pat, guard, Block(bindAssigns, mkStateTree(caseStates(num), symLookup)))
181183
}
182184
// 2. insert changed match tree at the end of the current state
183185
this += Match(scrutTree, newCases)
184186
new AsyncStateWithoutAwait(stats.toList, state, caseStates)
185187
}
186188

187189
def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = {
188-
this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup))
190+
this += mkStateTree(startLabelState, symLookup)
189191
new AsyncStateWithoutAwait(stats.toList, state, List(startLabelState))
190192
}
191193

@@ -226,9 +228,10 @@ trait ExprBuilder {
226228
for (stat <- stats) stat match {
227229
// the val name = await(..) pattern
228230
case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
231+
val onCompleteState = nextState()
229232
val afterAwaitState = nextState()
230233
val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd)
231-
asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await
234+
asyncStates += stateBuilder.resultWithAwait(awaitable, onCompleteState, afterAwaitState) // complete with await
232235
currState = afterAwaitState
233236
stateBuilder = new AsyncStateBuilder(currState, symLookup)
234237

@@ -296,8 +299,6 @@ trait ExprBuilder {
296299
def asyncStates: List[AsyncState]
297300

298301
def onCompleteHandler[T: WeakTypeTag]: Tree
299-
300-
def resumeFunTree[T: WeakTypeTag]: DefDef
301302
}
302303

303304
case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) {
@@ -330,7 +331,7 @@ trait ExprBuilder {
330331
val lastStateBody = Expr[T](lastState.body)
331332
val rhs = futureSystemOps.completeProm(
332333
Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryySuccess[T](lastStateBody))
333-
mkHandlerCase(lastState.state, rhs.tree)
334+
mkHandlerCase(lastState.state, Block(rhs.tree, Return(literalUnit)))
334335
}
335336
asyncStates.toList match {
336337
case s :: Nil =>
@@ -362,18 +363,23 @@ trait ExprBuilder {
362363
* }
363364
* }
364365
*/
365-
def resumeFunTree[T: WeakTypeTag]: DefDef =
366-
DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass),
366+
private def resumeFunTree[T: WeakTypeTag]: Tree =
367367
Try(
368-
Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]),
368+
Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ),
369369
List(
370370
CaseDef(
371371
Bind(name.t, Ident(nme.WILDCARD)),
372372
Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), {
373373
val t = Expr[Throwable](Ident(name.t))
374-
futureSystemOps.completeProm[T](
374+
val complete = futureSystemOps.completeProm[T](
375375
Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryyFailure[T](t)).tree
376-
})), EmptyTree))
376+
Block(complete :: Nil, Return(literalUnit))
377+
})), EmptyTree)
378+
379+
def forever(t: Tree): Tree = {
380+
val labelName = name.fresh("while$")
381+
LabelDef(labelName, Nil, Block(t :: Nil, Apply(Ident(labelName), Nil)))
382+
}
377383

378384
/**
379385
* Builds a `match` expression used as an onComplete handler.
@@ -387,8 +393,12 @@ trait ExprBuilder {
387393
* resume()
388394
* }
389395
*/
390-
def onCompleteHandler[T: WeakTypeTag]: Tree =
391-
Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList)
396+
def onCompleteHandler[T: WeakTypeTag]: Tree = {
397+
val onCompletes = initStates.flatMap(_.mkOnCompleteHandler[T]).toList
398+
forever {
399+
Block(resumeFunTree :: Nil, literalUnit)
400+
}
401+
}
392402
}
393403
}
394404

@@ -399,9 +409,6 @@ trait ExprBuilder {
399409

400410
case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef)
401411

402-
private def mkResumeApply(symLookup: SymLookup) =
403-
Apply(symLookup.memberRef(name.resume), Nil)
404-
405412
private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree =
406413
Assign(symLookup.memberRef(name.state), Literal(Constant(nextState)))
407414

@@ -411,5 +418,7 @@ trait ExprBuilder {
411418
private def mkHandlerCase(num: Int, rhs: Tree): CaseDef =
412419
CaseDef(Literal(Constant(num)), EmptyTree, rhs)
413420

414-
private def literalUnit = Literal(Constant(()))
421+
def literalUnit = Literal(Constant(()))
422+
423+
def literalNull = Literal(Constant(null))
415424
}

src/main/scala/scala/async/internal/StateAssigner.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
package scala.async.internal
66

77
private[async] final class StateAssigner {
8-
private var current = -1
8+
private var current = StateAssigner.Initial
99

10-
def nextState(): Int = {
11-
current += 1
12-
current
13-
}
10+
def nextState(): Int =
11+
try current finally current += 1
1412
}
13+
14+
object StateAssigner {
15+
final val Initial = 0
16+
}

src/main/scala/scala/async/internal/TransformUtils.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ private[async] trait TransformUtils {
4949

5050
private def isByName(fun: Tree): ((Int, Int) => Boolean) = {
5151
if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true
52+
else if (fun.tpe == null) (x, y) => false
5253
else {
5354
val paramss = fun.tpe.paramss
5455
val byNamess = paramss.map(_.map(_.isByNameParam))
@@ -72,10 +73,6 @@ private[async] trait TransformUtils {
7273
self.splice.contains(elem.splice)
7374
}
7475

75-
def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify {
76-
self.splice.apply(arg.splice)
77-
}
78-
7976
def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify {
8077
self.splice == other.splice
8178
}

src/test/scala/scala/async/TreeInterrogation.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class TreeInterrogation {
3636
functions.size mustBe 1
3737

3838
val varDefs = tree1.collect {
39-
case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name
39+
case vd @ ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) && vd.symbol.owner.isClass => name
4040
}
4141
varDefs.map(_.decoded.trim).toSet mustBe (Set("state", "await$1$1", "await$2$1"))
4242

@@ -49,7 +49,7 @@ class TreeInterrogation {
4949
&& !dd.symbol.asTerm.isAccessor && !dd.symbol.asTerm.isSetter => dd.name
5050
}
5151
}.flatten
52-
defDefs.map(_.decoded.trim).toSet mustBe (Set("foo$1", "apply", "resume", "<init>"))
52+
defDefs.map(_.decoded.trim).toSet mustBe (Set("foo$1", "apply", "<init>"))
5353
}
5454
}
5555

src/test/scala/scala/async/run/futures/FutureSpec.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,13 @@ class FutureSpec {
134134
Await.result(future1, defaultTimeout) mustBe ("10-14")
135135
intercept[NoSuchElementException] { Await.result(future2, defaultTimeout) }
136136
}
137+
138+
@Test def mini() {
139+
val future4 = async {
140+
await(Future.successful(0)).toString
141+
}
142+
Await.result(future4, defaultTimeout)
143+
}
137144

138145
@Test def `recover from exceptions`() {
139146
val future1 = Future(5)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright (C) 2012-2014 Typesafe Inc. <http://www.typesafe.com>
3+
*/
4+
5+
package scala.async
6+
package run
7+
package stackoverflow
8+
9+
import org.junit.Test
10+
import scala.async.internal.AsyncId
11+
12+
13+
class StackOverflowSpec {
14+
15+
@Test
16+
def stackSafety() {
17+
import AsyncId._
18+
async {
19+
var i = 100000000
20+
while (i > 0) {
21+
if (false) {
22+
await(())
23+
}
24+
i -= 1
25+
}
26+
}
27+
}
28+
}

0 commit comments

Comments
 (0)