Skip to content

Determistic output from the async macro #203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,49 @@ pomExtra := (
</developers>
)
OsgiKeys.exportPackage := Seq(s"scala.async.*;version=${version.value}")

commands += testDeterminism

def testDeterminism = Command.command("testDeterminism") { state =>
val extracted = Project.extract(state)
println("Running test:clean")
val (state1, _) = extracted.runTask(clean in Test in LocalRootProject, state)
println("Running test:compile")
val (state2, _) = extracted.runTask(compile in Test in LocalRootProject, state1)
val testClasses = extracted.get(classDirectory in Test)
val baseline: File = testClasses.getParentFile / (testClasses.getName + "-baseline")
baseline.mkdirs()
IO.copyDirectory(testClasses, baseline, overwrite = true)
IO.delete(testClasses)
println("Running test:compile")
val (state3, _) = extracted.runTask(compile in Test in LocalRootProject, state2)

import java.nio.file.FileVisitResult
import java.nio.file.{Files, Path}
import java.nio.file.SimpleFileVisitor
import java.nio.file.attribute.BasicFileAttributes
import java.util

def checkSameFileContents(one: Path, other: Path): Unit = {
Files.walkFileTree(one, new SimpleFileVisitor[Path]() {
override def visitFile(file: Path, attrs: BasicFileAttributes): FileVisitResult = {
val result: FileVisitResult = super.visitFile(file, attrs)
// get the relative file name from path "one"
val relativize: Path = one.relativize(file)
// construct the path for the counterpart file in "other"
val fileInOther: Path = other.resolve(relativize)
val otherBytes: Array[Byte] = Files.readAllBytes(fileInOther)
val thisBytes: Array[Byte] = Files.readAllBytes(file)
if (!(util.Arrays.equals(otherBytes, thisBytes))) {
throw new AssertionError(file + " is not equal to " + fileInOther)
}
return result
}
})
}
println("Comparing: " + baseline.toPath + " and " + testClasses.toPath)
checkSameFileContents(baseline.toPath, testClasses.toPath)
checkSameFileContents(testClasses.toPath, baseline.toPath)

state3
}
35 changes: 18 additions & 17 deletions src/main/scala/scala/async/internal/Lifter.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package scala.async.internal

import scala.collection.mutable

trait Lifter {
self: AsyncMacro =>
import c.universe._
Expand Down Expand Up @@ -37,7 +39,7 @@ trait Lifter {
}


val defs: Map[Tree, Int] = {
val defs: mutable.LinkedHashMap[Tree, Int] = {
/** Collect the DefTrees directly enclosed within `t` that have the same owner */
def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match {
case ld: LabelDef => Nil
Expand All @@ -48,33 +50,33 @@ trait Lifter {
companionship.record(childDefs)
childDefs
}
asyncStates.flatMap {
mutable.LinkedHashMap(asyncStates.flatMap {
asyncState =>
val defs = collectDirectlyEnclosedDefs(Block(asyncState.allStats: _*))
defs.map((_, asyncState.state))
}.toMap
}: _*)
}

// In which block are these symbols defined?
val symToDefiningState: Map[Symbol, Int] = defs.map {
val symToDefiningState: mutable.LinkedHashMap[Symbol, Int] = defs.map {
case (k, v) => (k.symbol, v)
}

// The definitions trees
val symToTree: Map[Symbol, Tree] = defs.map {
val symToTree: mutable.LinkedHashMap[Symbol, Tree] = defs.map {
case (k, v) => (k.symbol, k)
}

// The direct references of each definition tree
val defSymToReferenced: Map[Symbol, List[Symbol]] = defs.keys.map {
case tree => (tree.symbol, tree.collect {
val defSymToReferenced: mutable.LinkedHashMap[Symbol, List[Symbol]] = defs.map {
case (tree, _) => (tree.symbol, tree.collect {
case rt: RefTree if symToDefiningState.contains(rt.symbol) => rt.symbol
})
}.toMap
}

// The direct references of each block, excluding references of `DefTree`-s which
// are already accounted for.
val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = {
val stateIdToDirectlyReferenced: mutable.LinkedHashMap[Int, List[Symbol]] = {
val refs: List[(Int, Symbol)] = asyncStates.flatMap(
asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).flatMap(_.collect {
case rt: RefTree
Expand All @@ -84,8 +86,8 @@ trait Lifter {
toMultiMap(refs)
}

def liftableSyms: Set[Symbol] = {
val liftableMutableSet = collection.mutable.Set[Symbol]()
def liftableSyms: mutable.LinkedHashSet[Symbol] = {
val liftableMutableSet = mutable.LinkedHashSet[Symbol]()
def markForLift(sym: Symbol): Unit = {
if (!liftableMutableSet(sym)) {
liftableMutableSet += sym
Expand All @@ -97,19 +99,19 @@ trait Lifter {
}
}
// Start things with DefTrees directly referenced from statements from other states...
val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.toList.flatMap {
val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.iterator.flatMap {
case (i, syms) => syms.filter(sym => symToDefiningState(sym) != i)
}
}.toList
// .. and likewise for DefTrees directly referenced by other DefTrees from other states
val liftableRefsOfDefTrees = defSymToReferenced.toList.flatMap {
case (referee, referents) => referents.filter(sym => symToDefiningState(sym) != symToDefiningState(referee))
}
// Mark these for lifting, which will follow transitive references.
(liftableStatementRefs ++ liftableRefsOfDefTrees).foreach(markForLift)
liftableMutableSet.toSet
liftableMutableSet
}

val lifted = liftableSyms.map(symToTree).toList.map {
liftableSyms.iterator.map(symToTree).map {
t =>
val sym = t.symbol
val treeLifted = t match {
Expand Down Expand Up @@ -147,7 +149,6 @@ trait Lifter {
treeCopy.TypeDef(td, Modifiers(sym.flags), sym.name, tparams, rhs)
}
atPos(t.pos)(treeLifted)
}
lifted
}.toList
}
}
16 changes: 9 additions & 7 deletions src/main/scala/scala/async/internal/LiveVariables.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package scala.async.internal

import scala.collection.mutable

import java.util
import java.util.function.{IntConsumer, IntPredicate}

Expand All @@ -19,12 +21,12 @@ trait LiveVariables {
* @return a map mapping a state to the fields that should be nulled out
* upon resuming that state
*/
def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, List[Tree]] = {
def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): mutable.LinkedHashMap[Int, List[Tree]] = {
// live variables analysis:
// the result map indicates in which states a given field should be nulled out
val liveVarsMap: Map[Tree, StateSet] = liveVars(asyncStates, liftables)
val liveVarsMap: mutable.LinkedHashMap[Tree, StateSet] = liveVars(asyncStates, liftables)

var assignsOf = Map[Int, List[Tree]]()
var assignsOf = mutable.LinkedHashMap[Int, List[Tree]]()

for ((fld, where) <- liveVarsMap) {
where.foreach { new IntConsumer { def accept(state: Int): Unit = {
Expand Down Expand Up @@ -54,7 +56,7 @@ trait LiveVariables {
* @param liftables the lifted fields
* @return a map which indicates for a given field (the key) the states in which it should be nulled out
*/
def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, StateSet] = {
def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): mutable.LinkedHashMap[Tree, StateSet] = {
val liftedSyms: Set[Symbol] = // include only vars
liftables.iterator.filter {
case ValDef(mods, _, _, _) => mods.hasFlag(MUTABLE)
Expand Down Expand Up @@ -262,15 +264,15 @@ trait LiveVariables {
result
}

val lastUsages: Map[Tree, StateSet] =
liftables.iterator.map(fld => fld -> lastUsagesOf(fld, finalState)).toMap
val lastUsages: mutable.LinkedHashMap[Tree, StateSet] =
mutable.LinkedHashMap(liftables.map(fld => fld -> lastUsagesOf(fld, finalState)): _*)

if(AsyncUtils.verbose) {
for ((fld, lastStates) <- lastUsages)
AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.iterator.mkString(", ")}")
}

val nullOutAt: Map[Tree, StateSet] =
val nullOutAt: mutable.LinkedHashMap[Tree, StateSet] =
for ((fld, lastStates) <- lastUsages) yield {
var result = new StateSet
lastStates.foreach(new IntConsumer { def accept(s: Int): Unit = {
Expand Down
13 changes: 11 additions & 2 deletions src/main/scala/scala/async/internal/TransformUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ package scala.async.internal
import scala.reflect.macros.Context
import reflect.ClassTag
import scala.collection.immutable.ListMap
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/**
* Utilities used in both `ExprBuilder` and `AnfTransform`.
Expand Down Expand Up @@ -303,8 +305,15 @@ private[async] trait TransformUtils {
})
}

def toMultiMap[A, B](as: Iterable[(A, B)]): Map[A, List[B]] =
as.toList.groupBy(_._1).mapValues(_.map(_._2).toList).toMap
def toMultiMap[A, B](abs: Iterable[(A, B)]): mutable.LinkedHashMap[A, List[B]] = {
// LinkedHashMap for stable order of results.
val result = new mutable.LinkedHashMap[A, ListBuffer[B]]()
for ((a, b) <- abs) {
val buffer = result.getOrElseUpdate(a, new ListBuffer[B])
buffer += b
}
result.map { case (a, b) => (a, b.toList) }
}

// Attributed version of `TreeGen#mkCastPreservingAnnotations`
def mkAttributedCastPreservingAnnotations(tree: Tree, tp: Type): Tree = {
Expand Down