diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index c1381616..ec3a2a17 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -26,31 +26,24 @@ trait AsyncTransform { val anfTree = futureSystemOps.postAnfTransform(anfTree0) - val resumeFunTreeDummyBody = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Literal(Constant(()))) - val applyDefDefDummyBody: DefDef = { val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(futureSystemOps.tryType[Any]), EmptyTree))) - DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), Literal(Constant(()))) + DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), literalUnit) } // Create `ClassDef` of state machine with empty method bodies for `resume` and `apply`. val stateMachine: ClassDef = { val body: List[Tree] = { - val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) + val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(StateAssigner.Initial))) val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T](uncheckedBoundsResultTag)), futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree) val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext) val apply0DefDef: DefDef = { // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. - // See SI-1247 for the the optimization that avoids creatio - DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) - } - val extraValDef: ValDef = { - // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. - // See SI-1247 for the the optimization that avoids creatio - ValDef(NoMods, newTermName("extra"), TypeTree(definitions.UnitTpe), Literal(Constant(()))) + // See SI-1247 for the the optimization that avoids creation. + DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.apply), literalNull :: Nil)) } - List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef, extraValDef) + List(emptyConstructor, stateVar, result, execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef) } val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit]) @@ -92,8 +85,7 @@ trait AsyncTransform { val stateMachineSpliced: Tree = spliceMethodBodies( liftedFields, stateMachine, - atMacroPos(asyncBlock.onCompleteHandler[T]), - atMacroPos(asyncBlock.resumeFunTree[T].rhs) + atMacroPos(asyncBlock.onCompleteHandler[T]) ) def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) @@ -133,10 +125,9 @@ trait AsyncTransform { * @param liftables trees of definitions that are lifted to fields of the state machine class * @param tree `ClassDef` tree of the state machine class * @param applyBody tree of onComplete handler (`apply` method) - * @param resumeBody RHS of definition tree of `resume` method * @return transformed `ClassDef` tree of the state machine class */ - def spliceMethodBodies(liftables: List[Tree], tree: ClassDef, applyBody: Tree, resumeBody: Tree): Tree = { + def spliceMethodBodies(liftables: List[Tree], tree: ClassDef, applyBody: Tree): Tree = { val liftedSyms = liftables.map(_.symbol).toSet val stateMachineClass = tree.symbol liftedSyms.foreach { @@ -202,12 +193,6 @@ trait AsyncTransform { (api: TypingTransformApi) => val typedTree = fixup(dd, applyBody.changeOwner(enclosingOwner, dd.symbol), api) typedTree - - case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass => - (api: TypingTransformApi) => - val changed = resumeBody.changeOwner(enclosingOwner, dd.symbol) - val res = fixup(dd, changed, api) - res } result } diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index 2e313475..4e521c92 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -28,7 +28,7 @@ trait ExprBuilder { def nextStates: List[Int] - def mkHandlerCaseForState: CaseDef + def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None @@ -52,8 +52,8 @@ trait ExprBuilder { def nextStates: List[Int] = List(nextState) - def mkHandlerCaseForState: CaseDef = - mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup)) + def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = + mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup)) override val toString: String = s"AsyncState #$state, next = $nextState" @@ -63,7 +63,7 @@ trait ExprBuilder { * a branch of an `if` or a `match`. */ final class AsyncStateWithoutAwait(var stats: List[Tree], val state: Int, val nextStates: List[Int]) extends AsyncState { - override def mkHandlerCaseForState: CaseDef = + override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = mkHandlerCase(state, stats) override val toString: String = @@ -73,45 +73,54 @@ trait ExprBuilder { /** A sequence of statements that concludes with an `await` call. The `onComplete` * handler will unconditionally transition to `nextState`. */ - final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, nextState: Int, + final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, onCompleteState: Int, nextState: Int, val awaitable: Awaitable, symLookup: SymLookup) extends AsyncState { def nextStates: List[Int] = List(nextState) - override def mkHandlerCaseForState: CaseDef = { - val callOnComplete = futureSystemOps.onComplete(c.Expr(awaitable.expr), - c.Expr(This(tpnme.EMPTY)), c.Expr(Ident(name.execContext))).tree - mkHandlerCase(state, stats :+ callOnComplete) + override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = { + val fun = This(tpnme.EMPTY) + val callOnComplete = futureSystemOps.onComplete[Any, Unit](c.Expr[futureSystem.Fut[Any]](awaitable.expr), + c.Expr[futureSystem.Tryy[Any] => Unit](fun), c.Expr[futureSystem.ExecContext](Ident(name.execContext))).tree + val tryGetOrCallOnComplete = + if (futureSystemOps.continueCompletedFutureOnSameThread) + If(futureSystemOps.isCompleted(c.Expr[futureSystem.Fut[_]](awaitable.expr)).tree, + Block(ifIsFailureTree[T](futureSystemOps.getCompleted[Any](c.Expr[futureSystem.Fut[Any]](awaitable.expr)).tree) :: Nil, literalUnit), + Block(callOnComplete :: Nil, Return(literalUnit))) + else + Block(callOnComplete :: Nil, Return(literalUnit)) + mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup), tryGetOrCallOnComplete)) } + private def tryGetTree(tryReference: => Tree) = + Assign( + Ident(awaitable.resultName), + TypeApply(Select(futureSystemOps.tryyGet[Any](c.Expr[futureSystem.Tryy[Any]](tryReference)).tree, newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) + ) + + /* if (tr.isFailure) + * result.complete(tr.asInstanceOf[Try[T]]) + * else { + * = tr.get.asInstanceOf[] + * + * + * } + */ + def ifIsFailureTree[T: WeakTypeTag](tryReference: => Tree) = + If(futureSystemOps.tryyIsFailure(c.Expr[futureSystem.Tryy[T]](tryReference)).tree, + Block(futureSystemOps.completeProm[T]( + c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), + c.Expr[futureSystem.Tryy[T]]( + TypeApply(Select(tryReference, newTermName("asInstanceOf")), + List(TypeTree(futureSystemOps.tryType[T]))))).tree :: Nil, + Return(literalUnit)), + Block(List(tryGetTree(tryReference)), mkStateTree(nextState, symLookup)) + ) + override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = { - val tryGetTree = - Assign( - Ident(awaitable.resultName), - TypeApply(Select(futureSystemOps.tryyGet[T](c.Expr[futureSystem.Tryy[T]](Ident(symLookup.applyTrParam))).tree, newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) - ) - - /* if (tr.isFailure) - * result.complete(tr.asInstanceOf[Try[T]]) - * else { - * = tr.get.asInstanceOf[] - * - * - * } - */ - val ifIsFailureTree = - If(futureSystemOps.tryyIsFailure(c.Expr[futureSystem.Tryy[T]](Ident(symLookup.applyTrParam))).tree, - futureSystemOps.completeProm[T]( - c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), - c.Expr[futureSystem.Tryy[T]]( - TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")), - List(TypeTree(futureSystemOps.tryType[T]))))).tree, - Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup)) - ) - - Some(mkHandlerCase(state, List(ifIsFailureTree))) + Some(mkHandlerCase(onCompleteState, List(ifIsFailureTree[T](Ident(symLookup.applyTrParam))))) } override val toString: String = @@ -147,9 +156,10 @@ trait ExprBuilder { } def resultWithAwait(awaitable: Awaitable, + onCompleteState: Int, nextState: Int): AsyncState = { val effectiveNextState = nextJumpState.getOrElse(nextState) - new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup) + new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState, awaitable, symLookup) } def resultSimple(nextState: Int): AsyncState = { @@ -158,7 +168,7 @@ trait ExprBuilder { } def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { - def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup)) + def mkBranch(state: Int) = mkStateTree(state, symLookup) this += If(condTree, mkBranch(thenState), mkBranch(elseState)) new AsyncStateWithoutAwait(stats.toList, state, List(thenState, elseState)) } @@ -178,7 +188,7 @@ trait ExprBuilder { val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { case CaseDef(pat, guard, rhs) => val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal) - CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup))) + CaseDef(pat, guard, Block(bindAssigns, mkStateTree(caseStates(num), symLookup))) } // 2. insert changed match tree at the end of the current state this += Match(scrutTree, newCases) @@ -186,7 +196,7 @@ trait ExprBuilder { } def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { - this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup)) + this += mkStateTree(startLabelState, symLookup) new AsyncStateWithoutAwait(stats.toList, state, List(startLabelState)) } @@ -227,9 +237,10 @@ trait ExprBuilder { for (stat <- stats) stat match { // the val name = await(..) pattern case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => + val onCompleteState = nextState() val afterAwaitState = nextState() val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd) - asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await + asyncStates += stateBuilder.resultWithAwait(awaitable, onCompleteState, afterAwaitState) // complete with await currState = afterAwaitState stateBuilder = new AsyncStateBuilder(currState, symLookup) @@ -297,8 +308,6 @@ trait ExprBuilder { def asyncStates: List[AsyncState] def onCompleteHandler[T: WeakTypeTag]: Tree - - def resumeFunTree[T: WeakTypeTag]: DefDef } case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) { @@ -331,13 +340,13 @@ trait ExprBuilder { val lastStateBody = c.Expr[T](lastState.body) val rhs = futureSystemOps.completeProm( c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryySuccess[T](lastStateBody)) - mkHandlerCase(lastState.state, rhs.tree) + mkHandlerCase(lastState.state, Block(rhs.tree, Return(literalUnit))) } asyncStates.toList match { case s :: Nil => List(caseForLastState) case _ => - val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState + val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState[T] initCases :+ caseForLastState } } @@ -363,18 +372,23 @@ trait ExprBuilder { * } * } */ - def resumeFunTree[T: WeakTypeTag]: DefDef = - DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), + private def resumeFunTree[T: WeakTypeTag]: Tree = Try( - Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]), + Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ), List( CaseDef( Bind(name.t, Ident(nme.WILDCARD)), Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), { val t = c.Expr[Throwable](Ident(name.t)) - futureSystemOps.completeProm[T]( + val complete = futureSystemOps.completeProm[T]( c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryyFailure[T](t)).tree - })), EmptyTree)) + Block(complete :: Nil, Return(literalUnit)) + })), EmptyTree) + + def forever(t: Tree): Tree = { + val labelName = name.fresh("while$") + LabelDef(labelName, Nil, Block(t :: Nil, Apply(Ident(labelName), Nil))) + } /** * Builds a `match` expression used as an onComplete handler. @@ -388,8 +402,12 @@ trait ExprBuilder { * resume() * } */ - def onCompleteHandler[T: WeakTypeTag]: Tree = - Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) + def onCompleteHandler[T: WeakTypeTag]: Tree = { + val onCompletes = initStates.flatMap(_.mkOnCompleteHandler[T]).toList + forever { + Block(resumeFunTree :: Nil, literalUnit) + } + } } } @@ -400,9 +418,6 @@ trait ExprBuilder { case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef) - private def mkResumeApply(symLookup: SymLookup) = - Apply(symLookup.memberRef(name.resume), Nil) - private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree = Assign(symLookup.memberRef(name.state), Literal(Constant(nextState))) @@ -412,5 +427,7 @@ trait ExprBuilder { private def mkHandlerCase(num: Int, rhs: Tree): CaseDef = CaseDef(Literal(Constant(num)), EmptyTree, rhs) - private def literalUnit = Literal(Constant(())) + def literalUnit = Literal(Constant(())) + + def literalNull = Literal(Constant(null)) } diff --git a/src/main/scala/scala/async/internal/FutureSystem.scala b/src/main/scala/scala/async/internal/FutureSystem.scala index 1b1ffc3e..6fccfdd3 100644 --- a/src/main/scala/scala/async/internal/FutureSystem.scala +++ b/src/main/scala/scala/async/internal/FutureSystem.scala @@ -47,6 +47,12 @@ trait FutureSystem { def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[Tryy[A] => U], execContext: Expr[ExecContext]): Expr[Unit] + def continueCompletedFutureOnSameThread = false + def isCompleted(future: Expr[Fut[_]]): Expr[Boolean] = + throw new UnsupportedOperationException("isCompleted not supported by this FutureSystem") + def getCompleted[A: WeakTypeTag](future: Expr[Fut[A]]): Expr[Tryy[A]] = + throw new UnsupportedOperationException("getCompleted not supported by this FutureSystem") + /** Complete a promise with a value */ def completeProm[A](prom: Expr[Prom[A]], value: Expr[Tryy[A]]): Expr[Unit] @@ -100,6 +106,15 @@ object ScalaConcurrentFutureSystem extends FutureSystem { future.splice.onComplete(fun.splice)(execContext.splice) } + override def continueCompletedFutureOnSameThread: Boolean = true + + override def isCompleted(future: Expr[Fut[_]]): Expr[Boolean] = reify { + future.splice.isCompleted + } + override def getCompleted[A: WeakTypeTag](future: Expr[Fut[A]]): Expr[Tryy[A]] = reify { + future.splice.value.get + } + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { prom.splice.complete(value.splice) c.Expr[Unit](Literal(Constant(()))).splice diff --git a/src/main/scala/scala/async/internal/StateAssigner.scala b/src/main/scala/scala/async/internal/StateAssigner.scala index 8f0d5180..55e7a51c 100644 --- a/src/main/scala/scala/async/internal/StateAssigner.scala +++ b/src/main/scala/scala/async/internal/StateAssigner.scala @@ -5,10 +5,12 @@ package scala.async.internal private[async] final class StateAssigner { - private var current = -1 + private var current = StateAssigner.Initial - def nextState(): Int = { - current += 1 - current - } + def nextState(): Int = + try current finally current += 1 } + +object StateAssigner { + final val Initial = 0 +} \ No newline at end of file diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala index 5569ade4..bd7093f2 100644 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -119,6 +119,7 @@ private[async] trait TransformUtils { private def isByName(fun: Tree): ((Int, Int) => Boolean) = { if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true + else if (fun.tpe == null) (x, y) => false else { val paramss = fun.tpe.paramss val byNamess = paramss.map(_.map(_.asTerm.isByNameParam)) @@ -140,10 +141,6 @@ private[async] trait TransformUtils { self.splice.contains(elem.splice) } - def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { - self.splice.apply(arg.splice) - } - def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { self.splice == other.splice } diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 0d5e68ac..d6c619f8 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -36,7 +36,7 @@ class TreeInterrogation { functions.size mustBe 1 val varDefs = tree1.collect { - case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name + case vd @ ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) && vd.symbol.owner.isClass => name } varDefs.map(_.decoded.trim).toSet.toList.sorted mustStartWith (List("await$macro$", "await$macro$", "state")) @@ -49,7 +49,7 @@ class TreeInterrogation { && !dd.symbol.asTerm.isAccessor && !dd.symbol.asTerm.isSetter => dd.name } }.flatten - defDefs.map(_.decoded.trim).toSet.toList.sorted mustStartWith (List("", "apply", "foo$macro$", "resume")) + defDefs.map(_.decoded.trim).toList mustStartWith (List("foo$macro$", "", "apply", "apply")) } } diff --git a/src/test/scala/scala/async/run/SyncOptimizationSpec.scala b/src/test/scala/scala/async/run/SyncOptimizationSpec.scala new file mode 100644 index 00000000..dd649f46 --- /dev/null +++ b/src/test/scala/scala/async/run/SyncOptimizationSpec.scala @@ -0,0 +1,28 @@ +package scala.async.run + +import org.junit.Test +import scala.async.Async._ +import scala.concurrent._ +import scala.concurrent.duration._ +import ExecutionContext.Implicits._ + +class SyncOptimizationSpec { + @Test + def awaitOnCompletedFutureRunsOnSameThread: Unit = { + + def stackDepth = Thread.currentThread().getStackTrace.size + + val future = async { + val thread1 = Thread.currentThread + val stackDepth1 = stackDepth + + val f = await(Future.successful(1)) + val thread2 = Thread.currentThread + val stackDepth2 = stackDepth + assert(thread1 == thread2) + assert(stackDepth1 == stackDepth2) + } + Await.result(future, 10.seconds) + } + +} diff --git a/src/test/scala/scala/async/run/futures/FutureSpec.scala b/src/test/scala/scala/async/run/futures/FutureSpec.scala index 1761db50..362303e9 100644 --- a/src/test/scala/scala/async/run/futures/FutureSpec.scala +++ b/src/test/scala/scala/async/run/futures/FutureSpec.scala @@ -134,6 +134,13 @@ class FutureSpec { Await.result(future1, defaultTimeout) mustBe ("10-14") intercept[NoSuchElementException] { Await.result(future2, defaultTimeout) } } + + @Test def mini() { + val future4 = async { + await(Future.successful(0)).toString + } + Await.result(future4, defaultTimeout) + } @Test def `recover from exceptions`() { val future1 = Future(5) @@ -531,7 +538,6 @@ class FutureSpec { val f = async { await(future(5)) / 0 } Await.ready(f, defaultTimeout).value.get.toString mustBe expected.toString } - } diff --git a/src/test/scala/scala/async/run/stackoverflow/StackOverflowSpec.scala b/src/test/scala/scala/async/run/stackoverflow/StackOverflowSpec.scala new file mode 100644 index 00000000..2dc9b92b --- /dev/null +++ b/src/test/scala/scala/async/run/stackoverflow/StackOverflowSpec.scala @@ -0,0 +1,28 @@ +/* + * Copyright (C) 2012-2014 Typesafe Inc. + */ + +package scala.async +package run +package stackoverflow + +import org.junit.Test +import scala.async.internal.AsyncId + + +class StackOverflowSpec { + + @Test + def stackSafety() { + import AsyncId._ + async { + var i = 100000000 + while (i > 0) { + if (false) { + await(()) + } + i -= 1 + } + } + } +}