From cf12129e286a6056d97068094cbf86491c432395 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 12 Nov 2014 11:21:00 +0100 Subject: [PATCH 1/6] Added methods to prepare-for and transform a complete compilation unit tree. Should replace destructive inits. --- .../tools/dotc/transform/TreeTransform.scala | 27 ++++++++++++++++--- test/test/transform/TreeTransformerTest.scala | 10 +++---- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/dotty/tools/dotc/transform/TreeTransform.scala b/src/dotty/tools/dotc/transform/TreeTransform.scala index a70ab8aed71b..e12f6a8a6fd5 100644 --- a/src/dotty/tools/dotc/transform/TreeTransform.scala +++ b/src/dotty/tools/dotc/transform/TreeTransform.scala @@ -92,6 +92,8 @@ object TreeTransforms { def prepareForPackageDef(tree: PackageDef)(implicit ctx: Context) = this def prepareForStats(trees: List[Tree])(implicit ctx: Context) = this + def prepareForUnit(tree: Tree)(implicit ctx: Context) = this + def transformIdent(tree: Ident)(implicit ctx: Context, info: TransformerInfo): Tree = tree def transformSelect(tree: Select)(implicit ctx: Context, info: TransformerInfo): Tree = tree def transformThis(tree: This)(implicit ctx: Context, info: TransformerInfo): Tree = tree @@ -125,6 +127,8 @@ object TreeTransforms { def transformStats(trees: List[Tree])(implicit ctx: Context, info: TransformerInfo): List[Tree] = trees def transformOther(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = tree + def transformUnit(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = tree + /** Transform tree using all transforms of current group (including this one) */ def transform(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = info.group.transform(tree, info, 0) @@ -273,6 +277,7 @@ object TreeTransforms { nxPrepTemplate = index(transformations, "prepareForTemplate") nxPrepPackageDef = index(transformations, "prepareForPackageDef") nxPrepStats = index(transformations, "prepareForStats") + nxPrepUnit = index(transformations, "prepareForUnit") nxTransIdent = index(transformations, "transformIdent") nxTransSelect = index(transformations, "transformSelect") @@ -305,6 +310,7 @@ object TreeTransforms { nxTransTemplate = index(transformations, "transformTemplate") nxTransPackageDef = index(transformations, "transformPackageDef") nxTransStats = index(transformations, "transformStats") + nxTransUnit = index(transformations, "transformUnit") nxTransOther = index(transformations, "transformOther") } @@ -412,6 +418,7 @@ object TreeTransforms { var nxPrepTemplate: Array[Int] = _ var nxPrepPackageDef: Array[Int] = _ var nxPrepStats: Array[Int] = _ + var nxPrepUnit: Array[Int] = _ var nxTransIdent: Array[Int] = _ var nxTransSelect: Array[Int] = _ @@ -444,6 +451,7 @@ object TreeTransforms { var nxTransTemplate: Array[Int] = _ var nxTransPackageDef: Array[Int] = _ var nxTransStats: Array[Int] = _ + var nxTransUnit: Array[Int] = _ var nxTransOther: Array[Int] = _ } @@ -454,7 +462,7 @@ object TreeTransforms { override def run(implicit ctx: Context): Unit = { val curTree = ctx.compilationUnit.tpdTree - val newTree = transform(curTree) + val newTree = macroTransform(curTree) ctx.compilationUnit.tpdTree = newTree } @@ -517,8 +525,9 @@ object TreeTransforms { val prepForTemplate: Mutator[Template] = (trans, tree, ctx) => trans.prepareForTemplate(tree)(ctx) val prepForPackageDef: Mutator[PackageDef] = (trans, tree, ctx) => trans.prepareForPackageDef(tree)(ctx) val prepForStats: Mutator[List[Tree]] = (trans, trees, ctx) => trans.prepareForStats(trees)(ctx) + val prepForUnit: Mutator[Tree] = (trans, tree, ctx) => trans.prepareForUnit(tree)(ctx) - def transform(t: Tree)(implicit ctx: Context): Tree = { + def macroTransform(t: Tree)(implicit ctx: Context): Tree = { val initialTransformations = transformations val info = new TransformerInfo(initialTransformations, new NXTransformations(initialTransformations), this) initialTransformations.zipWithIndex.foreach { @@ -526,7 +535,9 @@ object TreeTransforms { transform.idx = id transform.init(ctx, info) } - transform(t, info, 0) + implicit val mutatedInfo: TransformerInfo = mutateTransformers(info, prepForUnit, info.nx.nxPrepUnit, t, 0) + if (mutatedInfo eq null) t + else goUnit(transform(t, mutatedInfo, 0), mutatedInfo.nx.nxTransUnit(0)) } @tailrec @@ -859,6 +870,15 @@ object TreeTransforms { } else tree } + @tailrec + final private[TreeTransforms] def goUnit(tree: Tree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { + if (cur < info.transformers.length) { + val trans = info.transformers(cur) + val t = trans.transformUnit(tree)(ctx.withPhase(trans.treeTransformPhase), info) + goUnit(t, info.nx.nxTransUnit(cur + 1)) + } else tree + } + final private[TreeTransforms] def goOther(tree: Tree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) @@ -1219,5 +1239,4 @@ object TreeTransforms { def transformSubTrees[Tr <: Tree](trees: List[Tr], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tr] = transformTrees(trees, info, current)(ctx).asInstanceOf[List[Tr]] } - } diff --git a/test/test/transform/TreeTransformerTest.scala b/test/test/transform/TreeTransformerTest.scala index a1839f2a18f9..fadc44ab9495 100644 --- a/test/test/transform/TreeTransformerTest.scala +++ b/test/test/transform/TreeTransformerTest.scala @@ -24,7 +24,7 @@ class TreeTransformerTest extends DottyTest { override def phaseName: String = "test" } - val transformed = transformer.transform(tree) + val transformed = transformer.macroTransform(tree) Assert.assertTrue("returns same tree if unmodified", tree eq transformed @@ -46,7 +46,7 @@ class TreeTransformerTest extends DottyTest { override def phaseName: String = "test" } - val transformed = transformer.transform(tree) + val transformed = transformer.macroTransform(tree) Assert.assertTrue("returns same tree if unmodified", transformed.toString.contains("List(ValDef(Modifiers(,,List()),d,TypeTree[TypeRef(ThisType(module class scala),Int)],Literal(Constant(2)))") @@ -77,7 +77,7 @@ class TreeTransformerTest extends DottyTest { override def phaseName: String = "test" } - val tr = transformer.transform(tree).toString + val tr = transformer.macroTransform(tree).toString Assert.assertTrue("node can rewrite children", tr.contains("Literal(Constant(2))") && !tr.contains("Literal(Constant(-1))") @@ -123,7 +123,7 @@ class TreeTransformerTest extends DottyTest { override def phaseName: String = "test" } - val tr = transformer.transform(tree).toString + val tr = transformer.macroTransform(tree).toString Assert.assertTrue("node can rewrite children", tr.contains("Literal(Constant(3))") @@ -191,7 +191,7 @@ class TreeTransformerTest extends DottyTest { override def phaseName: String = "test" } - val tr = transformer.transform(tree).toString + val tr = transformer.macroTransform(tree).toString Assert.assertTrue("transformations aren't invoked multiple times", transformed1 == 2 && transformed2 == 3 ) From e65f8dae23fac56962a8f27676b0082e4e287c37 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 12 Nov 2014 11:22:33 +0100 Subject: [PATCH 2/6] Replaced overridden init methods with prepareForUnit. This allows to move to a functional implementation later. Only exception: CapturedVars still uses init() because it contains a (dubious!) interaction with intialization and transformSym. Looking at this next. --- .../dotc/transform/CollectEntryPoints.scala | 3 ++- .../dotc/transform/InterceptedMethods.scala | 3 ++- .../tools/dotc/transform/LambdaLift.scala | 20 ++++++++++++------- .../dotc/transform/SyntheticMethods.scala | 3 ++- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/dotty/tools/dotc/transform/CollectEntryPoints.scala b/src/dotty/tools/dotc/transform/CollectEntryPoints.scala index 7d854b8c9029..1109d1f904d1 100644 --- a/src/dotty/tools/dotc/transform/CollectEntryPoints.scala +++ b/src/dotty/tools/dotc/transform/CollectEntryPoints.scala @@ -26,9 +26,10 @@ import dotty.tools.dotc.config.JavaPlatform class CollectEntryPoints extends MiniPhaseTransform { /** perform context-dependant initialization */ - override def init(implicit ctx: Context, info: TransformerInfo): Unit = { + override def prepareForUnit(tree: tpd.Tree)(implicit ctx: Context) = { entryPoints = collection.immutable.TreeSet.empty[Symbol](new SymbolOrdering()) assert(ctx.platform.isInstanceOf[JavaPlatform], "Java platform specific phase") + this } private var entryPoints: Set[Symbol] = _ diff --git a/src/dotty/tools/dotc/transform/InterceptedMethods.scala b/src/dotty/tools/dotc/transform/InterceptedMethods.scala index 463ab86c42c8..c4f5d4dac0dd 100644 --- a/src/dotty/tools/dotc/transform/InterceptedMethods.scala +++ b/src/dotty/tools/dotc/transform/InterceptedMethods.scala @@ -53,11 +53,12 @@ class InterceptedMethods extends MiniPhaseTransform { private var primitiveGetClassMethods: Set[Symbol] = _ /** perform context-dependant initialization */ - override def init(implicit ctx: Context, info: TransformerInfo): Unit = { + override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { poundPoundMethods = Set(defn.Any_##) Any_comparisons = Set(defn.Any_==, defn.Any_!=) interceptedMethods = poundPoundMethods ++ Any_comparisons primitiveGetClassMethods = Set[Symbol]() ++ defn.ScalaValueClasses.map(x => x.requiredMethod(nme.getClass_)) + this } // this should be removed if we have guarantee that ## will get Apply node diff --git a/src/dotty/tools/dotc/transform/LambdaLift.scala b/src/dotty/tools/dotc/transform/LambdaLift.scala index b239008fbc76..6a25d2bbcc46 100644 --- a/src/dotty/tools/dotc/transform/LambdaLift.scala +++ b/src/dotty/tools/dotc/transform/LambdaLift.scala @@ -278,20 +278,26 @@ class LambdaLift extends MiniPhaseTransform with IdentityDenotTransformer { this } } - override def init(implicit ctx: Context, info: TransformerInfo) = + override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { ctx.atPhase(thisTransform) { implicit ctx => - free.clear() - proxyMap.clear() - called.clear() - calledFromInner.clear() - liftedOwner.clear() - liftedDefs.clear() (new CollectDependencies).traverse(NoSymbol, ctx.compilationUnit.tpdTree) computeFreeVars() computeLiftedOwners() generateProxies()(ctx.withPhase(thisTransform.next)) liftLocals()(ctx.withPhase(thisTransform.next)) } + this + } + + override def transformUnit(tree: Tree)(implicit ctx: Context, info: TransformerInfo) = { + free.clear() + proxyMap.clear() + called.clear() + calledFromInner.clear() + liftedOwner.clear() + liftedDefs.clear() + tree + } private def currentEnclosure(implicit ctx: Context) = ctx.owner.enclosingMethod.skipConstructor diff --git a/src/dotty/tools/dotc/transform/SyntheticMethods.scala b/src/dotty/tools/dotc/transform/SyntheticMethods.scala index 128449efa73a..4e10b4aaff28 100644 --- a/src/dotty/tools/dotc/transform/SyntheticMethods.scala +++ b/src/dotty/tools/dotc/transform/SyntheticMethods.scala @@ -39,9 +39,10 @@ class SyntheticMethods extends MiniPhaseTransform with IdentityDenotTransformer private var valueSymbols: List[Symbol] = _ private var caseSymbols: List[Symbol] = _ - override def init(implicit ctx: Context, info: TransformerInfo) = { + override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { valueSymbols = List(defn.Any_hashCode, defn.Any_equals) caseSymbols = valueSymbols ++ List(defn.Any_toString, defn.Product_canEqual, defn.Product_productArity) + this } /** The synthetic methods of the case or value class `clazz`. From 46dd4a8d78a955dccee8674df3a962b5dae17856 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 12 Nov 2014 17:46:48 +0100 Subject: [PATCH 3/6] Make CaptguredVars use prepareForUnit instead of init. Required some refactoring. Instead of transformSym we now transform ValDefs as we prepare for them. The previous scheme could not control directly whetrher transformSym or collectCaptured would run first. Turns out that init ran before collectCaptured but prepareForUnit did not, leading to test failures in pos/capturedvars. --- .../tools/dotc/transform/CapturedVars.scala | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/dotty/tools/dotc/transform/CapturedVars.scala b/src/dotty/tools/dotc/transform/CapturedVars.scala index 91b81fea5d0a..b2118d3ec675 100644 --- a/src/dotty/tools/dotc/transform/CapturedVars.scala +++ b/src/dotty/tools/dotc/transform/CapturedVars.scala @@ -17,7 +17,7 @@ import SymUtils._ import collection.{ mutable, immutable } import collection.mutable.{ LinkedHashMap, LinkedHashSet, TreeSet } -class CapturedVars extends MiniPhaseTransform with SymTransformer { thisTransform => +class CapturedVars extends MiniPhaseTransform with IdentityDenotTransformer { thisTransform => import ast.tpd._ /** the following two members override abstract members in Transform */ @@ -44,16 +44,10 @@ class CapturedVars extends MiniPhaseTransform with SymTransformer { thisTransfor } } - override def init(implicit ctx: Context, info: TransformerInfo): Unit = + override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { (new CollectCaptured)(ctx.withPhase(thisTransform)).runOver(ctx.compilationUnit.tpdTree) - - override def transformSym(sd: SymDenotation)(implicit ctx: Context): SymDenotation = - if (captured(sd.symbol)) { - val newd = sd.copySymDenotation( - info = refCls(sd.info.classSymbol, sd.hasAnnotation(defn.VolatileAnnot)).typeRef) - newd.removeAnnotation(defn.VolatileAnnot) - newd - } else sd + this + } /** The {Volatile|}{Int|Double|...|Object}Ref class corresponding to the class `cls`, * depending on whether the reference should be @volatile @@ -68,6 +62,17 @@ class CapturedVars extends MiniPhaseTransform with SymTransformer { thisTransfor refCls(oldInfo.classSymbol, vble.isVolatile).typeRef } + override def prepareForValDef(vdef: ValDef)(implicit ctx: Context) = { + val sym = vdef.symbol + if (captured contains sym) { + val newd = sym.denot(ctx.withPhase(thisTransform)).copySymDenotation( + info = refCls(sym.info.classSymbol, sym.hasAnnotation(defn.VolatileAnnot)).typeRef) + newd.removeAnnotation(defn.VolatileAnnot) + newd.installAfter(thisTransform) + } + this + } + override def transformValDef(vdef: ValDef)(implicit ctx: Context, info: TransformerInfo): Tree = { val vble = vdef.symbol if (captured contains vble) { From c3b11ceddd46df7ecb2fb5640fd30efcd82b74c2 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 12 Nov 2014 18:09:45 +0100 Subject: [PATCH 4/6] Remove init method from TreeTransform Do not lead to temptation... --- src/dotty/tools/dotc/transform/TreeTransform.scala | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/dotty/tools/dotc/transform/TreeTransform.scala b/src/dotty/tools/dotc/transform/TreeTransform.scala index e12f6a8a6fd5..d3d8a183b848 100644 --- a/src/dotty/tools/dotc/transform/TreeTransform.scala +++ b/src/dotty/tools/dotc/transform/TreeTransform.scala @@ -142,9 +142,6 @@ object TreeTransforms { val last = info.transformers(info.transformers.length - 1) action(ctx.withPhase(last.phase.next)) } - - /** perform context-dependant initialization */ - def init(implicit ctx: Context, info: TransformerInfo): Unit = {} } /** A phase that defines a TreeTransform to be used in a group */ @@ -531,9 +528,7 @@ object TreeTransforms { val initialTransformations = transformations val info = new TransformerInfo(initialTransformations, new NXTransformations(initialTransformations), this) initialTransformations.zipWithIndex.foreach { - case (transform, id) => - transform.idx = id - transform.init(ctx, info) + case (transform, id) => transform.idx = id } implicit val mutatedInfo: TransformerInfo = mutateTransformers(info, prepForUnit, info.nx.nxPrepUnit, t, 0) if (mutatedInfo eq null) t From 8aae56b4f08806eda9745b0a980da151c0a7dc8c Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 12 Nov 2014 18:10:47 +0100 Subject: [PATCH 5/6] Make CapturedVars a functional transform. No global side effect on capturedVars anymore. --- .../tools/dotc/transform/CapturedVars.scala | 141 +++++++++--------- 1 file changed, 72 insertions(+), 69 deletions(-) diff --git a/src/dotty/tools/dotc/transform/CapturedVars.scala b/src/dotty/tools/dotc/transform/CapturedVars.scala index b2118d3ec675..14bb4a738cc0 100644 --- a/src/dotty/tools/dotc/transform/CapturedVars.scala +++ b/src/dotty/tools/dotc/transform/CapturedVars.scala @@ -17,91 +17,94 @@ import SymUtils._ import collection.{ mutable, immutable } import collection.mutable.{ LinkedHashMap, LinkedHashSet, TreeSet } -class CapturedVars extends MiniPhaseTransform with IdentityDenotTransformer { thisTransform => +class CapturedVars extends MiniPhase with IdentityDenotTransformer { thisTransform => import ast.tpd._ /** the following two members override abstract members in Transform */ val phaseName: String = "capturedVars" + val treeTransform = new Transform(Set()) - override def treeTransformPhase = thisTransform.next + class Transform(captured: collection.Set[Symbol]) extends TreeTransform { + def phase = thisTransform + override def treeTransformPhase = thisTransform.next - private var captured: mutable.HashSet[Symbol] = _ - - private class CollectCaptured(implicit ctx: Context) extends EnclosingMethodTraverser { - def traverse(enclMeth: Symbol, tree: Tree) = tree match { - case id: Ident => - val sym = id.symbol - if (sym.is(Mutable, butNot = Method) && sym.owner.isTerm && sym.enclosingMethod != enclMeth) { - ctx.log(i"capturing $sym in ${sym.enclosingMethod}, referenced from $enclMeth") - captured += sym - } - case _ => - foldOver(enclMeth, tree) - } - def runOver(tree: Tree) = { - captured = mutable.HashSet() - apply(NoSymbol, tree) + private class CollectCaptured(implicit ctx: Context) extends EnclosingMethodTraverser { + private val captured = mutable.HashSet[Symbol]() + def traverse(enclMeth: Symbol, tree: Tree) = tree match { + case id: Ident => + val sym = id.symbol + if (sym.is(Mutable, butNot = Method) && sym.owner.isTerm && sym.enclosingMethod != enclMeth) { + ctx.log(i"capturing $sym in ${sym.enclosingMethod}, referenced from $enclMeth") + captured += sym + } + case _ => + foldOver(enclMeth, tree) + } + def runOver(tree: Tree): collection.Set[Symbol] = { + apply(NoSymbol, tree) + captured + } } - } - override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { - (new CollectCaptured)(ctx.withPhase(thisTransform)).runOver(ctx.compilationUnit.tpdTree) - this - } + override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { + val captured = (new CollectCaptured)(ctx.withPhase(thisTransform)) + .runOver(ctx.compilationUnit.tpdTree) + new Transform(captured) + } - /** The {Volatile|}{Int|Double|...|Object}Ref class corresponding to the class `cls`, - * depending on whether the reference should be @volatile - */ - def refCls(cls: Symbol, isVolatile: Boolean)(implicit ctx: Context): Symbol = { - val refMap = if (isVolatile) defn.volatileRefClass else defn.refClass - refMap.getOrElse(cls, refMap(defn.ObjectClass)) - } + /** The {Volatile|}{Int|Double|...|Object}Ref class corresponding to the class `cls`, + * depending on whether the reference should be @volatile + */ + def refCls(cls: Symbol, isVolatile: Boolean)(implicit ctx: Context): Symbol = { + val refMap = if (isVolatile) defn.volatileRefClass else defn.refClass + refMap.getOrElse(cls, refMap(defn.ObjectClass)) + } - def capturedType(vble: Symbol)(implicit ctx: Context): Type = { - val oldInfo = vble.denot(ctx.withPhase(thisTransform)).info - refCls(oldInfo.classSymbol, vble.isVolatile).typeRef - } + def capturedType(vble: Symbol)(implicit ctx: Context): Type = { + val oldInfo = vble.denot(ctx.withPhase(thisTransform)).info + refCls(oldInfo.classSymbol, vble.isVolatile).typeRef + } - override def prepareForValDef(vdef: ValDef)(implicit ctx: Context) = { - val sym = vdef.symbol - if (captured contains sym) { - val newd = sym.denot(ctx.withPhase(thisTransform)).copySymDenotation( - info = refCls(sym.info.classSymbol, sym.hasAnnotation(defn.VolatileAnnot)).typeRef) - newd.removeAnnotation(defn.VolatileAnnot) - newd.installAfter(thisTransform) + override def prepareForValDef(vdef: ValDef)(implicit ctx: Context) = { + val sym = vdef.symbol + if (captured contains sym) { + val newd = sym.denot(ctx.withPhase(thisTransform)).copySymDenotation( + info = refCls(sym.info.classSymbol, sym.hasAnnotation(defn.VolatileAnnot)).typeRef) + newd.removeAnnotation(defn.VolatileAnnot) + newd.installAfter(thisTransform) + } + this } - this - } - override def transformValDef(vdef: ValDef)(implicit ctx: Context, info: TransformerInfo): Tree = { - val vble = vdef.symbol - if (captured contains vble) { - def boxMethod(name: TermName): Tree = - ref(vble.info.classSymbol.companionModule.info.member(name).symbol) - cpy.ValDef(vdef)( - rhs = vdef.rhs match { - case EmptyTree => boxMethod(nme.zero).appliedToNone.withPos(vdef.pos) - case arg => boxMethod(nme.create).appliedTo(arg) - }, - tpt = TypeTree(vble.info).withPos(vdef.tpt.pos)) + override def transformValDef(vdef: ValDef)(implicit ctx: Context, info: TransformerInfo): Tree = { + val vble = vdef.symbol + if (captured contains vble) { + def boxMethod(name: TermName): Tree = + ref(vble.info.classSymbol.companionModule.info.member(name).symbol) + cpy.ValDef(vdef)( + rhs = vdef.rhs match { + case EmptyTree => boxMethod(nme.zero).appliedToNone.withPos(vdef.pos) + case arg => boxMethod(nme.create).appliedTo(arg) + }, + tpt = TypeTree(vble.info).withPos(vdef.tpt.pos)) + } else vdef } - else vdef - } - override def transformIdent(id: Ident)(implicit ctx: Context, info: TransformerInfo): Tree = { - val vble = id.symbol - if (captured(vble)) - (id select nme.elem).ensureConforms(vble.denot(ctx.withPhase(thisTransform)).info) - else id - } + override def transformIdent(id: Ident)(implicit ctx: Context, info: TransformerInfo): Tree = { + val vble = id.symbol + if (captured(vble)) + (id select nme.elem).ensureConforms(vble.denot(ctx.withPhase(thisTransform)).info) + else id + } - override def transformAssign(tree: Assign)(implicit ctx: Context, info: TransformerInfo): Tree = { - val lhs1 = tree.lhs match { - case TypeApply(Select(qual @ Select(qual2, nme.elem), nme.asInstanceOf_), _) => - assert(captured(qual2.symbol)) - qual - case _ => tree.lhs + override def transformAssign(tree: Assign)(implicit ctx: Context, info: TransformerInfo): Tree = { + val lhs1 = tree.lhs match { + case TypeApply(Select(qual @ Select(qual2, nme.elem), nme.asInstanceOf_), _) => + assert(captured(qual2.symbol)) + qual + case _ => tree.lhs + } + cpy.Assign(tree)(lhs1, tree.rhs) } - cpy.Assign(tree)(lhs1, tree.rhs) } } \ No newline at end of file From e02db3453582bfab8600664f9c03bd8f7285be41 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 12 Nov 2014 18:36:07 +0100 Subject: [PATCH 6/6] Eliminate global state in LambdaLift State moved into local transforms which are allocated one per unit. Thsi allows lambda lifters on different units to run in parallel. --- .../tools/dotc/transform/LambdaLift.scala | 625 +++++++++--------- 1 file changed, 311 insertions(+), 314 deletions(-) diff --git a/src/dotty/tools/dotc/transform/LambdaLift.scala b/src/dotty/tools/dotc/transform/LambdaLift.scala index 6a25d2bbcc46..944dac208688 100644 --- a/src/dotty/tools/dotc/transform/LambdaLift.scala +++ b/src/dotty/tools/dotc/transform/LambdaLift.scala @@ -26,12 +26,15 @@ object LambdaLift { private class NoPath extends Exception } -class LambdaLift extends MiniPhaseTransform with IdentityDenotTransformer { thisTransform => +class LambdaLift extends MiniPhase with IdentityDenotTransformer { thisTransform => import LambdaLift._ import ast.tpd._ /** the following two members override abstract members in Transform */ val phaseName: String = "lambdaLift" + val treeTransform = new LambdaLifter + + override def relaxedTyping = true override def runsAfter: Set[Class[_ <: Phase]] = Set(classOf[Constructors]) // Constructors has to happen before LambdaLift because the lambda lift logic @@ -40,377 +43,371 @@ class LambdaLift extends MiniPhaseTransform with IdentityDenotTransformer { this // lambda lift for super calls right. Witness the implementation restrictions to // this effect in scalac. - override def treeTransformPhase = thisTransform.next - override def relaxedTyping = true + class LambdaLifter extends TreeTransform { + override def phase = thisTransform + override def treeTransformPhase = thisTransform.next - private type SymSet = TreeSet[Symbol] + private type SymSet = TreeSet[Symbol] - /** A map storing free variables of functions and classes */ - private val free = new LinkedHashMap[Symbol, SymSet] + /** A map storing free variables of functions and classes */ + private val free = new LinkedHashMap[Symbol, SymSet] - /** A map storing the free variable proxies of functions and classes. - * For every function and class, this is a map from the free variables - * of that function or class to the proxy symbols accessing them. - */ - private val proxyMap = new LinkedHashMap[Symbol, Map[Symbol, Symbol]] + /** A map storing the free variable proxies of functions and classes. + * For every function and class, this is a map from the free variables + * of that function or class to the proxy symbols accessing them. + */ + private val proxyMap = new LinkedHashMap[Symbol, Map[Symbol, Symbol]] - /** A hashtable storing calls between functions */ - private val called = new LinkedHashMap[Symbol, SymSet] + /** A hashtable storing calls between functions */ + private val called = new LinkedHashMap[Symbol, SymSet] - /** Symbols that are called from an inner class. */ - private val calledFromInner = new HashSet[Symbol] + /** Symbols that are called from an inner class. */ + private val calledFromInner = new HashSet[Symbol] - /** A map from local methods and classes to the owners to which they will be lifted as members. - * For methods and classes that do not have any dependencies this will be the enclosing package. - * symbols with packages as lifted owners will subsequently represented as static - * members of their toplevel class. - */ - private val liftedOwner = new HashMap[Symbol, Symbol] + /** A map from local methods and classes to the owners to which they will be lifted as members. + * For methods and classes that do not have any dependencies this will be the enclosing package. + * symbols with packages as lifted owners will subsequently represented as static + * members of their toplevel class. + */ + private val liftedOwner = new HashMap[Symbol, Symbol] - /** Buffers for lifted out classes and methods, indexed by owner */ - private val liftedDefs = new HashMap[Symbol, mutable.ListBuffer[Tree]] + /** Buffers for lifted out classes and methods, indexed by owner */ + private val liftedDefs = new HashMap[Symbol, mutable.ListBuffer[Tree]] - /** A flag to indicate whether new free variables have been found */ - private var changedFreeVars: Boolean = _ + /** A flag to indicate whether new free variables have been found */ + private var changedFreeVars: Boolean = _ - /** A flag to indicate whether lifted owners have changed */ - private var changedLiftedOwner: Boolean = _ + /** A flag to indicate whether lifted owners have changed */ + private var changedLiftedOwner: Boolean = _ - private val ord: Ordering[Symbol] = Ordering.by((_: Symbol).id) // Dotty deviation: Type annotation needed. TODO: figure out why - private def newSymSet = TreeSet.empty[Symbol](ord) + private val ord: Ordering[Symbol] = Ordering.by((_: Symbol).id) // Dotty deviation: Type annotation needed. TODO: figure out why + private def newSymSet = TreeSet.empty[Symbol](ord) - private def symSet(f: LinkedHashMap[Symbol, SymSet], sym: Symbol): SymSet = - f.getOrElseUpdate(sym, newSymSet) + private def symSet(f: LinkedHashMap[Symbol, SymSet], sym: Symbol): SymSet = + f.getOrElseUpdate(sym, newSymSet) - def proxies(sym: Symbol): List[Symbol] = { - val pm: Map[Symbol, Symbol] = proxyMap.getOrElse(sym, Map.empty) // Dotty deviation: Type annotation needed. TODO: figure out why - free.getOrElse(sym, Nil).toList.map(pm) - } + def proxies(sym: Symbol): List[Symbol] = { + val pm: Map[Symbol, Symbol] = proxyMap.getOrElse(sym, Map.empty) // Dotty deviation: Type annotation needed. TODO: figure out why + free.getOrElse(sym, Nil).toList.map(pm) + } - def narrowLiftedOwner(sym: Symbol, owner: Symbol)(implicit ctx: Context) = { - ctx.log(i"narrow lifted $sym to $owner") - if (sym.owner.skipConstructor.isTerm && + def narrowLiftedOwner(sym: Symbol, owner: Symbol)(implicit ctx: Context) = { + ctx.log(i"narrow lifted $sym to $owner") + if (sym.owner.skipConstructor.isTerm && owner.isProperlyContainedIn(liftedOwner(sym))) { - changedLiftedOwner = true - liftedOwner(sym) = owner + changedLiftedOwner = true + liftedOwner(sym) = owner + } } - } - /** Mark symbol `sym` as being free in `enclosure`, unless `sym` - * is defined in `enclosure` or there is a class between `enclosure`s owner - * and the owner of `sym`. - * Return `true` if there is no class between `enclosure` and - * the owner of sym. - * pre: sym.owner.isTerm, (enclosure.isMethod || enclosure.isClass) - * - * The idea of `markFree` is illustrated with an example: - * - * def f(x: int) = { - * class C { - * class D { - * val y = x - * } - * } - * } - * - * In this case `x` is free in the primary constructor of class `C`. - * but it is not free in `D`, because after lambda lift the code would be transformed - * as follows: - * - * def f(x$0: int) { - * class C(x$0: int) { - * val x$1 = x$0 - * class D { - * val y = outer.x$1 - * } - * } - * } - */ - private def markFree(sym: Symbol, enclosure: Symbol)(implicit ctx: Context): Boolean = try { - if (!enclosure.exists) throw new NoPath - ctx.log(i"mark free: ${sym.showLocated} with owner ${sym.maybeOwner} marked free in $enclosure") - (enclosure == sym.enclosure) || { - ctx.debuglog(i"$enclosure != ${sym.enclosure}") - narrowLiftedOwner(enclosure, sym.enclosingClass) - if (enclosure.is(PackageClass) || + /** Mark symbol `sym` as being free in `enclosure`, unless `sym` + * is defined in `enclosure` or there is a class between `enclosure`s owner + * and the owner of `sym`. + * Return `true` if there is no class between `enclosure` and + * the owner of sym. + * pre: sym.owner.isTerm, (enclosure.isMethod || enclosure.isClass) + * + * The idea of `markFree` is illustrated with an example: + * + * def f(x: int) = { + * class C { + * class D { + * val y = x + * } + * } + * } + * + * In this case `x` is free in the primary constructor of class `C`. + * but it is not free in `D`, because after lambda lift the code would be transformed + * as follows: + * + * def f(x$0: int) { + * class C(x$0: int) { + * val x$1 = x$0 + * class D { + * val y = outer.x$1 + * } + * } + * } + */ + private def markFree(sym: Symbol, enclosure: Symbol)(implicit ctx: Context): Boolean = try { + if (!enclosure.exists) throw new NoPath + ctx.log(i"mark free: ${sym.showLocated} with owner ${sym.maybeOwner} marked free in $enclosure") + (enclosure == sym.enclosure) || { + ctx.debuglog(i"$enclosure != ${sym.enclosure}") + narrowLiftedOwner(enclosure, sym.enclosingClass) + if (enclosure.is(PackageClass) || !markFree(sym, enclosure.skipConstructor.enclosure)) false - else { - val ss = symSet(free, enclosure) - if (!ss(sym)) { - ss += sym - changedFreeVars = true - ctx.debuglog(i"$sym is free in $enclosure") + else { + val ss = symSet(free, enclosure) + if (!ss(sym)) { + ss += sym + changedFreeVars = true + ctx.debuglog(i"$sym is free in $enclosure") + } + !enclosure.isClass } - !enclosure.isClass } + } catch { + case ex: NoPath => + println(i"error lambda lifting ${ctx.compilationUnit}: $sym is not visible from $enclosure") + throw ex } - } - catch { - case ex: NoPath => - println(i"error lambda lifting ${ctx.compilationUnit}: $sym is not visible from $enclosure") - throw ex - } - private def markCalled(callee: Symbol, caller: Symbol)(implicit ctx: Context): Unit = { - ctx.debuglog(i"mark called: $callee of ${callee.owner} is called by $caller") - assert(callee.skipConstructor.owner.isTerm) - symSet(called, caller) += callee - if (callee.enclosingClass != caller.enclosingClass) calledFromInner += callee - } + private def markCalled(callee: Symbol, caller: Symbol)(implicit ctx: Context): Unit = { + ctx.debuglog(i"mark called: $callee of ${callee.owner} is called by $caller") + assert(callee.skipConstructor.owner.isTerm) + symSet(called, caller) += callee + if (callee.enclosingClass != caller.enclosingClass) calledFromInner += callee + } - private class CollectDependencies(implicit ctx: Context) extends EnclosingMethodTraverser { - def traverse(enclMeth: Symbol, tree: Tree) = try { //debug - val enclosure = enclMeth.skipConstructor - val sym = tree.symbol - tree match { - case tree: Ident => - if (sym.maybeOwner.isTerm) { - if (sym is Label) - assert(enclosure == sym.enclosure, + private class CollectDependencies(implicit ctx: Context) extends EnclosingMethodTraverser { + def traverse(enclMeth: Symbol, tree: Tree) = try { //debug + val enclosure = enclMeth.skipConstructor + val sym = tree.symbol + tree match { + case tree: Ident => + if (sym.maybeOwner.isTerm) { + if (sym is Label) + assert(enclosure == sym.enclosure, i"attempt to refer to label $sym from nested $enclosure") - else if (sym is Method) markCalled(sym, enclosure) - else if (sym.isTerm) markFree(sym, enclosure) - } - case tree: Select => - if (sym.isConstructor && sym.owner.owner.isTerm) - markCalled(sym, enclosure) - case tree: This => - val thisClass = tree.symbol.asClass - val enclClass = enclosure.enclosingClass - if (!thisClass.isStaticOwner && thisClass != enclClass) - narrowLiftedOwner(enclosure, - if (enclClass.isContainedIn(thisClass)) thisClass - else enclClass) // unknown this reference, play it safe and assume the narrowest possible owner - case tree: DefDef => - if (sym.owner.isTerm && !sym.is(Label)) liftedOwner(sym) = sym.topLevelClass.owner - else if (sym.isPrimaryConstructor && sym.owner.owner.isTerm) symSet(called, sym) += sym.owner - case tree: TypeDef => - if (sym.owner.isTerm) liftedOwner(sym) = sym.topLevelClass.owner - case tree: Template => - liftedDefs(tree.symbol.owner) = new mutable.ListBuffer - case _ => + else if (sym is Method) markCalled(sym, enclosure) + else if (sym.isTerm) markFree(sym, enclosure) + } + case tree: Select => + if (sym.isConstructor && sym.owner.owner.isTerm) + markCalled(sym, enclosure) + case tree: This => + val thisClass = tree.symbol.asClass + val enclClass = enclosure.enclosingClass + if (!thisClass.isStaticOwner && thisClass != enclClass) + narrowLiftedOwner(enclosure, + if (enclClass.isContainedIn(thisClass)) thisClass + else enclClass) // unknown this reference, play it safe and assume the narrowest possible owner + case tree: DefDef => + if (sym.owner.isTerm && !sym.is(Label)) liftedOwner(sym) = sym.topLevelClass.owner + else if (sym.isPrimaryConstructor && sym.owner.owner.isTerm) symSet(called, sym) += sym.owner + case tree: TypeDef => + if (sym.owner.isTerm) liftedOwner(sym) = sym.topLevelClass.owner + case tree: Template => + liftedDefs(tree.symbol.owner) = new mutable.ListBuffer + case _ => + } + foldOver(enclosure, tree) + } catch { //debug + case ex: Exception => + println(i"$ex while traversing $tree") + throw ex } - foldOver(enclosure, tree) - } catch { //debug - case ex: Exception => - println(i"$ex while traversing $tree") - throw ex } - } - /** Compute final free variables map `fvs by closing over caller dependencies. */ - private def computeFreeVars()(implicit ctx: Context): Unit = - do { - changedFreeVars = false - for { - caller <- called.keys - callee <- called(caller) - fvs <- free get callee - fv <- fvs - } markFree(fv, caller) - } while (changedFreeVars) - - /** Compute final liftedOwner map by closing over caller dependencies */ - private def computeLiftedOwners()(implicit ctx: Context): Unit = - do { - changedLiftedOwner = false - for { - caller <- called.keys - callee <- called(caller) - } narrowLiftedOwner(caller, liftedOwner(callee.skipConstructor)) - } while (changedLiftedOwner) - - private def newName(sym: Symbol)(implicit ctx: Context): Name = { - def freshen(prefix: String): Name = { - val fname = ctx.freshName(prefix) - if (sym.isType) fname.toTypeName else fname.toTermName + /** Compute final free variables map `fvs by closing over caller dependencies. */ + private def computeFreeVars()(implicit ctx: Context): Unit = + do { + changedFreeVars = false + for { + caller <- called.keys + callee <- called(caller) + fvs <- free get callee + fv <- fvs + } markFree(fv, caller) + } while (changedFreeVars) + + /** Compute final liftedOwner map by closing over caller dependencies */ + private def computeLiftedOwners()(implicit ctx: Context): Unit = + do { + changedLiftedOwner = false + for { + caller <- called.keys + callee <- called(caller) + } narrowLiftedOwner(caller, liftedOwner(callee.skipConstructor)) + } while (changedLiftedOwner) + + private def newName(sym: Symbol)(implicit ctx: Context): Name = { + def freshen(prefix: String): Name = { + val fname = ctx.freshName(prefix) + if (sym.isType) fname.toTypeName else fname.toTermName + } + if (sym.isAnonymousFunction && sym.owner.is(Method, butNot = Label)) + freshen(sym.name.toString ++ NJ ++ sym.owner.name ++ NJ) + else if (sym is ModuleClass) + freshen(sym.sourceModule.name.toString ++ NJ).moduleClassName + else + freshen(sym.name.toString ++ NJ) } - if (sym.isAnonymousFunction && sym.owner.is(Method, butNot = Label)) - freshen(sym.name.toString ++ NJ ++ sym.owner.name ++ NJ) - else if (sym is ModuleClass) - freshen(sym.sourceModule.name.toString ++ NJ).moduleClassName - else - freshen(sym.name.toString ++ NJ) - } - private def generateProxies()(implicit ctx: Context): Unit = - for ((owner, freeValues) <- free.toIterator) { - val newFlags = Synthetic | (if (owner.isClass) ParamAccessor | Private else Param) - ctx.debuglog(i"free var proxy: ${owner.showLocated}, ${freeValues.toList}%, %") - proxyMap(owner) = { - for (fv <- freeValues.toList) yield { - val proxyName = newName(fv) - val proxy = ctx.newSymbol(owner, proxyName.asTermName, newFlags, fv.info, coord = fv.coord) - if (owner.isClass) proxy.enteredAfter(thisTransform) - (fv, proxy) - } - }.toMap - } + private def generateProxies()(implicit ctx: Context): Unit = + for ((owner, freeValues) <- free.toIterator) { + val newFlags = Synthetic | (if (owner.isClass) ParamAccessor | Private else Param) + ctx.debuglog(i"free var proxy: ${owner.showLocated}, ${freeValues.toList}%, %") + proxyMap(owner) = { + for (fv <- freeValues.toList) yield { + val proxyName = newName(fv) + val proxy = ctx.newSymbol(owner, proxyName.asTermName, newFlags, fv.info, coord = fv.coord) + if (owner.isClass) proxy.enteredAfter(thisTransform) + (fv, proxy) + } + }.toMap + } - private def liftedInfo(local: Symbol)(implicit ctx: Context): Type = local.info match { - case mt @ MethodType(pnames, ptypes) => - val ps = proxies(local.skipConstructor) - MethodType( - pnames ++ ps.map(_.name.asTermName), - ptypes ++ ps.map(_.info), - mt.resultType) - case info => info - } + private def liftedInfo(local: Symbol)(implicit ctx: Context): Type = local.info match { + case mt @ MethodType(pnames, ptypes) => + val ps = proxies(local.skipConstructor) + MethodType( + pnames ++ ps.map(_.name.asTermName), + ptypes ++ ps.map(_.info), + mt.resultType) + case info => info + } - private def liftLocals()(implicit ctx: Context): Unit = { - for ((local, lOwner) <- liftedOwner) { - val (newOwner, maybeStatic) = - if (lOwner is Package) (local.topLevelClass, JavaStatic) - else (lOwner, EmptyFlags) - val maybeNotJavaPrivate = if (calledFromInner(local)) NotJavaPrivate else EmptyFlags - local.copySymDenotation( - owner = newOwner, - name = newName(local), - initFlags = local.flags | Private | maybeStatic | maybeNotJavaPrivate, - info = liftedInfo(local)).installAfter(thisTransform) - if (local.isClass) - for (member <- local.asClass.decls) - if (member.isConstructor) { - val linfo = liftedInfo(member) - if (linfo ne member.info) - member.copySymDenotation(info = linfo).installAfter(thisTransform) - } + private def liftLocals()(implicit ctx: Context): Unit = { + for ((local, lOwner) <- liftedOwner) { + val (newOwner, maybeStatic) = + if (lOwner is Package) (local.topLevelClass, JavaStatic) + else (lOwner, EmptyFlags) + val maybeNotJavaPrivate = if (calledFromInner(local)) NotJavaPrivate else EmptyFlags + local.copySymDenotation( + owner = newOwner, + name = newName(local), + initFlags = local.flags | Private | maybeStatic | maybeNotJavaPrivate, + info = liftedInfo(local)).installAfter(thisTransform) + if (local.isClass) + for (member <- local.asClass.decls) + if (member.isConstructor) { + val linfo = liftedInfo(member) + if (linfo ne member.info) + member.copySymDenotation(info = linfo).installAfter(thisTransform) + } + } } - } - override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { - ctx.atPhase(thisTransform) { implicit ctx => + private def init(implicit ctx: Context) = { (new CollectDependencies).traverse(NoSymbol, ctx.compilationUnit.tpdTree) computeFreeVars() computeLiftedOwners() generateProxies()(ctx.withPhase(thisTransform.next)) liftLocals()(ctx.withPhase(thisTransform.next)) } - this - } - override def transformUnit(tree: Tree)(implicit ctx: Context, info: TransformerInfo) = { - free.clear() - proxyMap.clear() - called.clear() - calledFromInner.clear() - liftedOwner.clear() - liftedDefs.clear() - tree - } + override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { + val lifter = new LambdaLifter + lifter.init(ctx.withPhase(thisTransform)) + lifter + } - private def currentEnclosure(implicit ctx: Context) = - ctx.owner.enclosingMethod.skipConstructor + private def currentEnclosure(implicit ctx: Context) = + ctx.owner.enclosingMethod.skipConstructor - private def inCurrentOwner(sym: Symbol)(implicit ctx: Context) = - sym.enclosure == currentEnclosure + private def inCurrentOwner(sym: Symbol)(implicit ctx: Context) = + sym.enclosure == currentEnclosure - private def proxy(sym: Symbol)(implicit ctx: Context): Symbol = { - def searchIn(enclosure: Symbol): Symbol = { - if (!enclosure.exists) { - def enclosures(encl: Symbol): List[Symbol] = - if (encl.exists) encl :: enclosures(encl.enclosure) else Nil - throw new IllegalArgumentException(i"Could not find proxy for ${sym.showDcl} in ${sym.ownersIterator.toList}, encl = $currentEnclosure, owners = ${currentEnclosure.ownersIterator.toList}%, %; enclosures = ${enclosures(currentEnclosure)}%, %") - } - ctx.debuglog(i"searching for $sym(${sym.owner}) in $enclosure") - proxyMap get enclosure match { - case Some(pmap) => - pmap get sym match { - case Some(proxy) => return proxy - case none => - } - case none => + private def proxy(sym: Symbol)(implicit ctx: Context): Symbol = { + def searchIn(enclosure: Symbol): Symbol = { + if (!enclosure.exists) { + def enclosures(encl: Symbol): List[Symbol] = + if (encl.exists) encl :: enclosures(encl.enclosure) else Nil + throw new IllegalArgumentException(i"Could not find proxy for ${sym.showDcl} in ${sym.ownersIterator.toList}, encl = $currentEnclosure, owners = ${currentEnclosure.ownersIterator.toList}%, %; enclosures = ${enclosures(currentEnclosure)}%, %") + } + ctx.debuglog(i"searching for $sym(${sym.owner}) in $enclosure") + proxyMap get enclosure match { + case Some(pmap) => + pmap get sym match { + case Some(proxy) => return proxy + case none => + } + case none => + } + searchIn(enclosure.enclosure) } - searchIn(enclosure.enclosure) + if (inCurrentOwner(sym)) sym else searchIn(currentEnclosure) } - if (inCurrentOwner(sym)) sym else searchIn(currentEnclosure) - } - - private def memberRef(sym: Symbol)(implicit ctx: Context, info: TransformerInfo): Tree = { - val clazz = sym.enclosingClass - val qual = - if (clazz.isStaticOwner) singleton(clazz.thisType) - else outer(ctx.withPhase(thisTransform)).path(clazz) - transformFollowingDeep(qual.select(sym)) - } - private def proxyRef(sym: Symbol)(implicit ctx: Context, info: TransformerInfo): Tree = { - val psym = proxy(sym)(ctx.withPhase(thisTransform)) - transformFollowingDeep(if (psym.owner.isTerm) ref(psym) else memberRef(psym)) - } + private def memberRef(sym: Symbol)(implicit ctx: Context, info: TransformerInfo): Tree = { + val clazz = sym.enclosingClass + val qual = + if (clazz.isStaticOwner) singleton(clazz.thisType) + else outer(ctx.withPhase(thisTransform)).path(clazz) + transformFollowingDeep(qual.select(sym)) + } - private def addFreeArgs(sym: Symbol, args: List[Tree])(implicit ctx: Context, info: TransformerInfo) = - free get sym match { - case Some(fvs) => args ++ fvs.toList.map(proxyRef(_)) - case _ => args + private def proxyRef(sym: Symbol)(implicit ctx: Context, info: TransformerInfo): Tree = { + val psym = proxy(sym)(ctx.withPhase(thisTransform)) + transformFollowingDeep(if (psym.owner.isTerm) ref(psym) else memberRef(psym)) } - private def addFreeParams(tree: Tree, proxies: List[Symbol])(implicit ctx: Context, info: TransformerInfo): Tree = proxies match { - case Nil => tree - case proxies => - val ownProxies = - if (!tree.symbol.isConstructor) proxies - else proxies.map(_.copy(owner = tree.symbol, flags = Synthetic | Param)) - val freeParamDefs = ownProxies.map(proxy => - transformFollowingDeep(ValDef(proxy.asTerm).withPos(tree.pos)).asInstanceOf[ValDef]) - tree match { - case tree: DefDef => - cpy.DefDef(tree)(vparamss = tree.vparamss.map(_ ++ freeParamDefs)) - case tree: Template => - cpy.Template(tree)(body = tree.body ++ freeParamDefs) + private def addFreeArgs(sym: Symbol, args: List[Tree])(implicit ctx: Context, info: TransformerInfo) = + free get sym match { + case Some(fvs) => args ++ fvs.toList.map(proxyRef(_)) + case _ => args } - } - private def liftDef(tree: MemberDef)(implicit ctx: Context, info: TransformerInfo): Tree = { - val buf = liftedDefs(tree.symbol.owner) - transformFollowing(rename(tree, tree.symbol.name)).foreachInThicket(buf += _) - EmptyTree - } + private def addFreeParams(tree: Tree, proxies: List[Symbol])(implicit ctx: Context, info: TransformerInfo): Tree = proxies match { + case Nil => tree + case proxies => + val ownProxies = + if (!tree.symbol.isConstructor) proxies + else proxies.map(_.copy(owner = tree.symbol, flags = Synthetic | Param)) + val freeParamDefs = ownProxies.map(proxy => + transformFollowingDeep(ValDef(proxy.asTerm).withPos(tree.pos)).asInstanceOf[ValDef]) + tree match { + case tree: DefDef => + cpy.DefDef(tree)(vparamss = tree.vparamss.map(_ ++ freeParamDefs)) + case tree: Template => + cpy.Template(tree)(body = tree.body ++ freeParamDefs) + } + } - private def needsLifting(sym: Symbol) = liftedOwner contains sym + private def liftDef(tree: MemberDef)(implicit ctx: Context, info: TransformerInfo): Tree = { + val buf = liftedDefs(tree.symbol.owner) + transformFollowing(rename(tree, tree.symbol.name)).foreachInThicket(buf += _) + EmptyTree + } - override def transformIdent(tree: Ident)(implicit ctx: Context, info: TransformerInfo) = { - val sym = tree.symbol - tree.tpe match { - case tpe @ TermRef(prefix, _) => - if ((prefix eq NoPrefix) && sym.enclosure != currentEnclosure && !sym.isStatic) - (if (sym is Method) memberRef(sym) else proxyRef(sym)).withPos(tree.pos) - else if (!prefixIsElidable(tpe)) ref(tpe) - else tree - case _ => - tree + private def needsLifting(sym: Symbol) = liftedOwner contains sym + + override def transformIdent(tree: Ident)(implicit ctx: Context, info: TransformerInfo) = { + val sym = tree.symbol + tree.tpe match { + case tpe @ TermRef(prefix, _) => + if ((prefix eq NoPrefix) && sym.enclosure != currentEnclosure && !sym.isStatic) + (if (sym is Method) memberRef(sym) else proxyRef(sym)).withPos(tree.pos) + else if (!prefixIsElidable(tpe)) ref(tpe) + else tree + case _ => + tree + } } - } - override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo) = - cpy.Apply(tree)(tree.fun, addFreeArgs(tree.symbol.skipConstructor, tree.args)).withPos(tree.pos) + override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo) = + cpy.Apply(tree)(tree.fun, addFreeArgs(tree.symbol.skipConstructor, tree.args)).withPos(tree.pos) - override def transformClosure(tree: Closure)(implicit ctx: Context, info: TransformerInfo) = - cpy.Closure(tree)(env = addFreeArgs(tree.meth.symbol, tree.env)) + override def transformClosure(tree: Closure)(implicit ctx: Context, info: TransformerInfo) = + cpy.Closure(tree)(env = addFreeArgs(tree.meth.symbol, tree.env)) + + override def transformDefDef(tree: DefDef)(implicit ctx: Context, info: TransformerInfo) = { + val sym = tree.symbol + val proxyHolder = sym.skipConstructor + if (needsLifting(proxyHolder)) { + val paramsAdded = addFreeParams(tree, proxies(proxyHolder)).asInstanceOf[DefDef] + if (sym.isConstructor) paramsAdded else liftDef(paramsAdded) + } + else tree + } - override def transformDefDef(tree: DefDef)(implicit ctx: Context, info: TransformerInfo) = { - val sym = tree.symbol - val proxyHolder = sym.skipConstructor - if (needsLifting(proxyHolder)) { - val paramsAdded = addFreeParams(tree, proxies(proxyHolder)).asInstanceOf[DefDef] - if (sym.isConstructor) paramsAdded else liftDef(paramsAdded) + override def transformReturn(tree: Return)(implicit ctx: Context, info: TransformerInfo) = tree.expr match { + case Block(stats, value) => + Block(stats, Return(value, tree.from)).withPos(tree.pos) + case _ => + tree } - else tree - } - override def transformReturn(tree: Return)(implicit ctx: Context, info: TransformerInfo) = tree.expr match { - case Block(stats, value) => - Block(stats, Return(value, tree.from)).withPos(tree.pos) - case _ => - tree - } + override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) = { + val cls = ctx.owner + val impl = addFreeParams(tree, proxies(cls)).asInstanceOf[Template] + cpy.Template(impl)(body = impl.body ++ liftedDefs.remove(cls).get) + } - override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) = { - val cls = ctx.owner - val impl = addFreeParams(tree, proxies(cls)).asInstanceOf[Template] - cpy.Template(impl)(body = impl.body ++ liftedDefs.remove(cls).get) + override def transformTypeDef(tree: TypeDef)(implicit ctx: Context, info: TransformerInfo) = + if (needsLifting(tree.symbol)) liftDef(tree) else tree } - - override def transformTypeDef(tree: TypeDef)(implicit ctx: Context, info: TransformerInfo) = - if (needsLifting(tree.symbol)) liftDef(tree) else tree }