Skip to content

Commit 4fc5463

Browse files
committed
Fix crashers in do/while and while(await(..))
The new tree shapes handled for do/while look like: // type checked async({ val b = false; doWhile$1(){ await(()); if (b) doWhile$1() else () }; () }) We had to change ExprBuilder to create states for the if/else that concludes the doWhile body, and also loosen the assertion that the label jump must be the last thing we see. We also have to look for more than just `containsAwait` when deciding whether an `If` needs to be transformed into states; it might also contain a jump to the enclosing label that is on the other side of an `await`, and hence needs to be a state transition instead.
1 parent 23f94c2 commit 4fc5463

File tree

5 files changed

+73
-34
lines changed

5 files changed

+73
-34
lines changed

src/main/scala/scala/async/internal/AnfTransform.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,14 @@ private[async] trait AnfTransform {
162162
def _transformToList(tree: Tree): List[Tree] = trace(tree) {
163163
val containsAwait = tree exists isAwait
164164
if (!containsAwait) {
165-
List(tree)
165+
tree match {
166+
case Block(stats, expr) =>
167+
// avoids nested block in `while(await(false)) ...`.
168+
// TODO I think `containsAwait` really should return true if the code contains a label jump to an enclosing
169+
// while/doWhile and there is an await *anywhere* inside that construct.
170+
stats :+ expr
171+
case _ => List(tree)
172+
}
166173
} else tree match {
167174
case Select(qual, sel) =>
168175
val stats :+ expr = linearize.transformToList(qual)

src/main/scala/scala/async/internal/ExprBuilder.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,11 @@ trait ExprBuilder {
127127
private var nextJumpState: Option[Int] = None
128128

129129
def +=(stat: Tree): this.type = {
130-
assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat")
130+
stat match {
131+
case Literal(Constant(())) => // This case occurs in do/while
132+
case _ =>
133+
assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat")
134+
}
131135
def addStat() = stats += stat
132136
stat match {
133137
case Apply(fun, Nil) =>
@@ -228,7 +232,7 @@ trait ExprBuilder {
228232
currState = afterAwaitState
229233
stateBuilder = new AsyncStateBuilder(currState, symLookup)
230234

231-
case If(cond, thenp, elsep) if stat exists isAwait =>
235+
case If(cond, thenp, elsep) if (stat exists isAwait) || containsForiegnLabelJump(stat) =>
232236
checkForUnsupportedAwait(cond)
233237

234238
val thenStartState = nextState()

src/main/scala/scala/async/internal/TransformUtils.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,19 @@ private[async] trait TransformUtils {
9696
treeInfo.isExprSafeToInline(tree)
9797
}
9898

99+
// `while(await(x))` ... or `do { await(x); ... } while(...)` contain an `If` that loops;
100+
// we must break that `If` into states so that it convert the label jump into a state machine
101+
// transition
102+
final def containsForiegnLabelJump(t: Tree): Boolean = {
103+
val labelDefs = t.collect {
104+
case ld: LabelDef => ld.symbol
105+
}.toSet
106+
t.exists {
107+
case rt: RefTree => !(labelDefs contains rt.symbol)
108+
case _ => false
109+
}
110+
}
111+
99112
/** Map a list of arguments to:
100113
* - A list of argument Trees
101114
* - A list of auxillary results.

src/test/scala/scala/async/TreeInterrogation.scala

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -66,43 +66,18 @@ object TreeInterrogation extends App {
6666
withDebug {
6767
val cm = reflect.runtime.currentMirror
6868
val tb = mkToolbox("-cp ${toolboxClasspath} -Xprint:typer -uniqid")
69-
import scala.async.internal.AsyncTestLV._
69+
import scala.async.internal.AsyncId._
7070
val tree = tb.parse(
7171
"""
72-
| import scala.async.internal.AsyncTestLV._
73-
| import scala.async.internal.AsyncTestLV
74-
|
75-
| case class MCell[T](var v: T)
76-
| val f = async { MCell(1) }
77-
|
78-
| def m1(x: MCell[Int], y: Int): Int =
79-
| async { x.v + y }
80-
| case class Cell[T](v: T)
81-
|
72+
| import scala.async.internal.AsyncId._
8273
| async {
83-
| // state #1
84-
| val a: MCell[Int] = await(f) // await$13$1
85-
| // state #2
86-
| var y = MCell(0)
87-
|
88-
| while (a.v < 10) {
89-
| // state #4
90-
| a.v = a.v + 1
91-
| y = MCell(await(a).v + 1) // await$14$1
92-
| // state #7
74+
| var b = true
75+
| while(await(b)) {
76+
| b = false
9377
| }
94-
|
95-
| // state #3
96-
| assert(AsyncTestLV.log.exists(entry => entry._1 == "await$14$1"))
97-
|
98-
| val b = await(m1(a, y.v)) // await$15$1
99-
| // state #8
100-
| assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10))))
101-
| assert(AsyncTestLV.log.exists(_ == ("y$1" -> MCell(11))))
102-
| b
78+
| await(b)
10379
| }
10480
|
105-
|
10681
| """.stripMargin)
10782
println(tree)
10883
val tree1 = tb.typeCheck(tree.duplicate)

src/test/scala/scala/async/run/ifelse0/WhileSpec.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,44 @@ class WhileSpec {
7676
}
7777
result mustBe ()
7878
}
79+
80+
@Test def doWhile() {
81+
import AsyncId._
82+
val result = async {
83+
var b = 0
84+
var x = ""
85+
await(do {
86+
x += "1"
87+
x += await("2")
88+
x += "3"
89+
b += await(1)
90+
} while (b < 2))
91+
await(x)
92+
}
93+
result mustBe "123123"
94+
}
95+
96+
@Test def whileAwaitCondition() {
97+
import AsyncId._
98+
val result = async {
99+
var b = true
100+
while(await(b)) {
101+
b = false
102+
}
103+
await(b)
104+
}
105+
result mustBe false
106+
}
107+
108+
@Test def doWhileAwaitCondition() {
109+
import AsyncId._
110+
val result = async {
111+
var b = true
112+
do {
113+
b = false
114+
} while(await(b))
115+
b
116+
}
117+
result mustBe false
118+
}
79119
}

0 commit comments

Comments
 (0)