Skip to content

Commit bee8e37

Browse files
committed
New ShortcutImplicits phase
Optimizes implicit closures by avoiding closure creation where possible.
1 parent 50bc580 commit bee8e37

File tree

7 files changed

+195
-1
lines changed

7 files changed

+195
-1
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class Compiler {
6161
new PatternMatcher, // Compile pattern matches
6262
new ExplicitOuter, // Add accessors to outer classes from nested ones.
6363
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
64+
new ShortcutImplicits, // Allow implicit functions without creating closures
6465
new CrossCastAnd, // Normalize selections involving intersection types.
6566
new Splitter), // Expand selections involving union types into conditionals
6667
List(new VCInlineMethods, // Inlines calls to value class methods

compiler/src/dotty/tools/dotc/config/JavaPlatform.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class JavaPlatform extends Platform {
1818
currentClassPath = Some(new PathResolver().result)
1919
val cp = currentClassPath.get
2020
//println(cp)
21+
//println("------------------")
2122
cp
2223
}
2324

compiler/src/dotty/tools/dotc/core/NameOps.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ object NameOps {
188188

189189
def errorName: N = likeTyped(name ++ nme.ERROR)
190190

191+
def directName: N = likeTyped(name ++ DIRECT_SUFFIX)
192+
191193
def freshened(implicit ctx: Context): N =
192194
likeTyped(
193195
if (name.isModuleClassName) name.stripModuleClassSuffix.freshened.moduleClassName

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ object StdNames {
129129
val COMPANION_MODULE_METHOD: N = "companion$module"
130130
val COMPANION_CLASS_METHOD: N = "companion$class"
131131
val TRAIT_SETTER_SEPARATOR: N = "$_setter_$"
132+
val DIRECT_SUFFIX: N = "$direct"
132133

133134
// value types (and AnyRef) are all used as terms as well
134135
// as (at least) arguments to the @specialize annotation.

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2376,7 +2376,9 @@ object Types {
23762376
protected def computeSignature(implicit ctx: Context): Signature =
23772377
resultSignature.prepend(paramTypes, isJava)
23782378

2379-
def derivedMethodType(paramNames: List[TermName], paramTypes: List[Type], resType: Type)(implicit ctx: Context) =
2379+
def derivedMethodType(paramNames: List[TermName] = this.paramNames,
2380+
paramTypes: List[Type] = this.paramTypes,
2381+
resType: Type = this.resType)(implicit ctx: Context) =
23802382
if ((paramNames eq this.paramNames) && (paramTypes eq this.paramTypes) && (resType eq this.resType)) this
23812383
else {
23822384
val resTypeFn = (x: MethodType) => resType.subst(this, x)
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import TreeTransforms._
5+
import core.DenotTransformers.IdentityDenotTransformer
6+
import core.Symbols._
7+
import core.Contexts._
8+
import core.Types._
9+
import core.Flags._
10+
import core.Decorators._
11+
import core.StdNames.nme
12+
import core.Names._
13+
import core.NameOps._
14+
import ast.Trees._
15+
import ast.tpd
16+
import collection.mutable
17+
18+
/** This phase optimizes code using implicit function types, by applying two rewrite rules.
19+
* Let IF be the implicit function type
20+
*
21+
* implicit Us => R
22+
*
23+
* (1) A method definition
24+
*
25+
* def m(xs: Ts): IF = implicit (ys: Us) => E
26+
*
27+
* is expanded to two methods:
28+
*
29+
* def m(xs: Ts): IF = implicit (ys: Us) => m$direct(xs)(ys)
30+
* def m$direct(xs: Ts)(ys: Us): R = E
31+
*
32+
* (and equivalently for methods with type parameters or a different number of value parameter lists).
33+
* An abstract method definition
34+
*
35+
* def m(xs: Ts): IF
36+
*
37+
* is expanded to:
38+
*
39+
* def m(xs: Ts): IF
40+
* def m$direct(xs: Ts, ys: Us): R
41+
*
42+
* (2) A reference `qual.apply` where `qual` has implicit function type and
43+
* `qual` refers to a method `m` is rewritten to a reference to `m$direct`,
44+
* keeping the same type and value arguments as they are found in `qual`.
45+
*/
46+
class ShortcutImplicits extends MiniPhase with IdentityDenotTransformer { thisTransform =>
47+
import tpd._
48+
49+
override def phaseName: String = "shortcutImplicits"
50+
val treeTransform = new Transform
51+
52+
class Transform extends TreeTransform {
53+
def phase = thisTransform
54+
55+
override def prepareForUnit(tree: Tree)(implicit ctx: Context) = new Transform
56+
57+
/** A map to cache mapping local methods to their direct counterparts.
58+
* A fresh map is created for each unit.
59+
*/
60+
private val directMeth = new mutable.HashMap[Symbol, Symbol]
61+
62+
/** @pre The type's final result type is an implicit function type `implicit Ts => R`.
63+
* @return The type of the `apply` member of `implicit Ts => R`.
64+
*/
65+
private def directInfo(info: Type)(implicit ctx: Context): Type = info match {
66+
case info: PolyType => info.derivedPolyType(resType = directInfo(info.resultType))
67+
case info: MethodType => info.derivedMethodType(resType = directInfo(info.resultType))
68+
case info: ExprType => directInfo(info.resultType)
69+
case info => info.member(nme.apply).info
70+
}
71+
72+
/** A new `m$direct` method to accompany the given method `m` */
73+
private def newDirectMethod(sym: Symbol)(implicit ctx: Context): Symbol =
74+
sym.copy(
75+
name = sym.name.directName,
76+
flags = sym.flags | Synthetic,
77+
info = directInfo(sym.info))
78+
79+
/** The direct method `m$direct` that accompanies the given method `m`.
80+
* Create one if it does not exist already.
81+
*/
82+
private def directMethod(sym: Symbol)(implicit ctx: Context): Symbol =
83+
if (sym.owner.isClass) {
84+
val direct = sym.owner.info.member(sym.name.directName)
85+
.suchThat(_.info matches directInfo(sym.info)).symbol
86+
if (direct.maybeOwner == sym.owner) direct
87+
else newDirectMethod(sym).enteredAfter(thisTransform)
88+
}
89+
else directMeth.getOrElseUpdate(sym, newDirectMethod(sym))
90+
91+
92+
/** Transform `qual.apply` occurrences according to rewrite rule (2) above */
93+
override def transformSelect(tree: Select)(implicit ctx: Context, info: TransformerInfo) =
94+
if (tree.name == nme.apply &&
95+
defn.isImplicitFunctionType(tree.qualifier.tpe.widen) &&
96+
tree.qualifier.symbol.is(Method, butNot = Accessor)) {
97+
def directQual(tree: Tree): Tree = tree match {
98+
case Apply(fn, args) => cpy.Apply(tree)(directQual(fn), args)
99+
case TypeApply(fn, args) => cpy.TypeApply(tree)(directQual(fn), args)
100+
case Block(stats, expr) => cpy.Block(tree)(stats, directQual(expr))
101+
case tree: RefTree =>
102+
cpy.Ref(tree)(tree.name.directName)
103+
.withType(directMethod(tree.symbol).termRef)
104+
}
105+
directQual(tree.qualifier)
106+
} else tree
107+
108+
/** Transform methods with implicit function type result according to rewrite rule (1) above */
109+
override def transformDefDef(mdef: DefDef)(implicit ctx: Context, info: TransformerInfo): Tree = {
110+
val original = mdef.symbol
111+
if (defn.isImplicitFunctionType(original.info.finalResultType)) {
112+
val direct = directMethod(original)
113+
114+
def splitClosure(tree: Tree): (List[Type] => List[List[Tree]] => Tree, Tree) = tree match {
115+
case Block(Nil, expr) => splitClosure(expr)
116+
case Block((meth @ DefDef(nme.ANON_FUN, Nil, clparams :: Nil, _, _)) :: Nil, cl: Closure) =>
117+
val tparamSyms = mdef.tparams.map(_.symbol)
118+
val vparamSymss = mdef.vparamss.map(_.map(_.symbol))
119+
val clparamSyms = clparams.map(_.symbol)
120+
val remappedCore = (ts: List[Type]) => (prefss: List[List[Tree]]) =>
121+
meth.rhs
122+
.subst(tparamSyms ::: (vparamSymss.flatten ++ clparamSyms),
123+
ts.map(_.typeSymbol) ::: prefss.flatten.map(_.symbol))
124+
.changeOwnerAfter(original, direct, thisTransform)
125+
.changeOwnerAfter(meth.symbol, direct, thisTransform)
126+
val forwarder = ref(direct)
127+
.appliedToTypeTrees(tparamSyms.map(ref(_)))
128+
.appliedToArgss(vparamSymss.map(_.map(ref(_))) :+ clparamSyms.map(ref(_)))
129+
val fwdClosure = cpy.Block(tree)(cpy.DefDef(meth)(rhs = forwarder) :: Nil, cl)
130+
(remappedCore, fwdClosure)
131+
case EmptyTree =>
132+
(_ => _ => EmptyTree, EmptyTree)
133+
}
134+
135+
val (remappedCore, fwdClosure) = splitClosure(mdef.rhs)
136+
val originalDef = cpy.DefDef(mdef)(rhs = fwdClosure)
137+
val directDef = polyDefDef(direct.asTerm, remappedCore)
138+
Thicket(originalDef, directDef)
139+
}
140+
else mdef
141+
}
142+
}
143+
}

tests/run/implicitFuns.scala

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,47 @@ object TransactionalExpansion {
211211
}
212212
}
213213

214+
object TransactionalAbstracted {
215+
type Transactional[T] = implicit Transaction => T
216+
217+
trait TransOps {
218+
def thisTransaction: Transactional[Transaction]
219+
def f1(x: Int): Transactional[Int]
220+
def f2(x: Int): Transactional[Int]
221+
def f3(x: Int): Transactional[Int]
222+
}
223+
224+
object TransOpsObj extends TransOps {
225+
226+
def thisTransaction: Transactional[Transaction] = implicitly[Transaction]
227+
228+
def f1(x: Int): Transactional[Int] = {
229+
thisTransaction.println(s"first step: $x")
230+
f2(x + 1)
231+
}
232+
def f2(x: Int): Transactional[Int] = {
233+
thisTransaction.println(s"second step: $x")
234+
f3(x * x)
235+
}
236+
def f3(x: Int): Transactional[Int] = {
237+
thisTransaction.println(s"third step: $x")
238+
if (x % 2 != 0) thisTransaction.abort()
239+
x
240+
}
241+
}
242+
243+
val transOps: TransOps = TransOpsObj
244+
245+
def transaction[T](op: Transactional[T]) = {
246+
implicit val trans: Transaction = new Transaction
247+
op
248+
trans.commit()
249+
}
250+
251+
def main(args: Array[String]) = {
252+
transaction {
253+
val res = transOps.f1(args.length)
254+
println(if (transOps.thisTransaction.isAborted) "aborted" else s"result: $res")
255+
}
256+
}
257+
}

0 commit comments

Comments
 (0)