Skip to content

Commit 9a69dbc

Browse files
committed
Don't treat inline closures specially.
Rely on call-by-name parameters instead. Fix "unused defs" logic so that referred-to cbn parameters are not eliminated.
1 parent 9802897 commit 9a69dbc

File tree

9 files changed

+184
-53
lines changed

9 files changed

+184
-53
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -810,16 +810,11 @@ object desugar {
810810
* ==>
811811
* def $anonfun(params) = body
812812
* Closure($anonfun)
813-
*
814-
* If `inlineable` is true, tag $anonfun with an @inline annotation.
815813
*/
816-
def makeClosure(params: List[ValDef], body: Tree, tpt: Tree = TypeTree(), inlineable: Boolean)(implicit ctx: Context) = {
817-
var mods = synthetic | Artifact
818-
if (inlineable) mods |= Inline
814+
def makeClosure(params: List[ValDef], body: Tree, tpt: Tree = TypeTree())(implicit ctx: Context) =
819815
Block(
820-
DefDef(nme.ANON_FUN, Nil, params :: Nil, tpt, body).withMods(mods),
816+
DefDef(nme.ANON_FUN, Nil, params :: Nil, tpt, body).withMods(synthetic | Artifact),
821817
Closure(Nil, Ident(nme.ANON_FUN), EmptyTree))
822-
}
823818

824819
/** If `nparams` == 1, expand partial function
825820
*

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -645,9 +645,8 @@ trait Checking {
645645
case tp: TermRef if tp.symbol.is(InlineParam) => // ok
646646
case tp => tp.widenTermRefExpr match {
647647
case tp: ConstantType if exprPurity(tree) >= purityLevel => // ok
648-
case tp if defn.isFunctionType(tp) && exprPurity(tree) >= purityLevel => // ok
649648
case _ =>
650-
if (!ctx.erasedTypes) ctx.error(em"$what must be a constant expression or a function", tree.pos)
649+
if (!ctx.erasedTypes) ctx.error(em"$what must be a constant expression", tree.pos)
651650
}
652651
}
653652
}

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 131 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,31 @@ class Inliner(call: tpd.Tree, rhs: tpd.Tree)(implicit ctx: Context) {
218218
private def newSym(name: Name, flags: FlagSet, info: Type): Symbol =
219219
ctx.newSymbol(ctx.owner, name, flags, info, coord = call.pos)
220220

221+
/** A binding for the parameter of an inlined method. This is a `val` def for
222+
* by-value parameters and a `def` def for by-name parameters. `val` defs inherit
223+
* inline annotations from their parameters. The generated `def` is appended
224+
* to `bindingsBuf`.
225+
* @param name the name of the parameter
226+
* @param paramtp the type of the parameter
227+
* @param arg the argument corresponding to the parameter
228+
* @param bindingsBuf the buffer to which the definition should be appended
229+
*/
230+
private def paramBindingDef(name: Name, paramtp: Type, arg: Tree,
231+
bindingsBuf: mutable.ListBuffer[ValOrDefDef]): ValOrDefDef = {
232+
val argtpe = arg.tpe.dealias
233+
def isByName = paramtp.dealias.isInstanceOf[ExprType]
234+
val inlineFlag = if (paramtp.hasAnnotation(defn.InlineParamAnnot)) Inline else EmptyFlags
235+
val (bindingFlags, bindingType) =
236+
if (isByName) (Method, ExprType(argtpe.widen))
237+
else (inlineFlag, argtpe.widen)
238+
val boundSym = newSym(name, bindingFlags, bindingType).asTerm
239+
val binding =
240+
if (isByName) DefDef(boundSym, arg.changeOwner(ctx.owner, boundSym))
241+
else ValDef(boundSym, arg)
242+
bindingsBuf += binding
243+
binding
244+
}
245+
221246
/** Populate `paramBinding` and `bindingsBuf` by matching parameters with
222247
* corresponding arguments. `bindingbuf` will be further extended later by
223248
* proxies to this-references.
@@ -230,20 +255,9 @@ class Inliner(call: tpd.Tree, rhs: tpd.Tree)(implicit ctx: Context) {
230255
computeParamBindings(tp.resultType, Nil, argss)
231256
case tp: MethodType =>
232257
(tp.paramNames, tp.paramInfos, argss.head).zipped.foreach { (name, paramtp, arg) =>
233-
def isByName = paramtp.dealias.isInstanceOf[ExprType]
234258
paramBinding(name) = arg.tpe.dealias match {
235259
case _: SingletonType if isIdempotentExpr(arg) => arg.tpe
236-
case argtpe =>
237-
val inlineFlag = if (paramtp.hasAnnotation(defn.InlineParamAnnot)) Inline else EmptyFlags
238-
val (bindingFlags, bindingType) =
239-
if (isByName) (inlineFlag | Method, ExprType(argtpe.widen))
240-
else (inlineFlag, argtpe.widen)
241-
val boundSym = newSym(name, bindingFlags, bindingType).asTerm
242-
val binding =
243-
if (isByName) DefDef(boundSym, arg.changeOwner(ctx.owner, boundSym))
244-
else ValDef(boundSym, arg)
245-
bindingsBuf += binding
246-
boundSym.termRef
260+
case _ => paramBindingDef(name, paramtp, arg, bindingsBuf).symbol.termRef
247261
}
248262
}
249263
computeParamBindings(tp.resultType, targs, argss.tail)
@@ -265,7 +279,7 @@ class Inliner(call: tpd.Tree, rhs: tpd.Tree)(implicit ctx: Context) {
265279
* The proxy is not yet entered in `bindingsBuf`; that will come later.
266280
* 2. If given type refers to a parameter, make `paramProxy` refer to the entry stored
267281
* in `paramNames` under the parameter's name. This roundabout way to bind parameter
268-
* references to proxies is done because we not known a priori what the parameter
282+
* references to proxies is done because we don't know a priori what the parameter
269283
* references of a method are (we only know the method's type, but that contains TypeParamRefs
270284
* and MethodParams, not TypeRefs or TermRefs.
271285
*/
@@ -374,16 +388,15 @@ class Inliner(call: tpd.Tree, rhs: tpd.Tree)(implicit ctx: Context) {
374388
// The final expansion runs a typing pass over the inlined tree. See InlineTyper for details.
375389
val expansion1 = InlineTyper.typed(expansion, pt)(inlineCtx)
376390

377-
/** Does given definition bind a closure that will be inlined? */
378-
def bindsDeadInlineable(defn: ValOrDefDef) = Ident(defn.symbol.termRef) match {
379-
case InlineableArg(_) => !InlineTyper.retainedInlineables.contains(defn.symbol)
380-
case _ => false
381-
}
382-
383391
/** All bindings in `bindingsBuf` except bindings of inlineable closures */
384-
val bindings = bindingsBuf.toList.filterNot(bindsDeadInlineable).map(_.withPos(call.pos))
392+
val bindings = bindingsBuf.toList.map(_.withPos(call.pos))
393+
394+
inlining.println(i"original bindings = $bindings%\n%")
395+
inlining.println(i"original expansion = $expansion1")
385396

386-
tpd.Inlined(call, bindings, expansion1)
397+
val (finalBindings, finalExpansion) = dropUnusedDefs(bindings, expansion1)
398+
399+
tpd.Inlined(call, finalBindings, finalExpansion)
387400
}
388401
}
389402

@@ -414,8 +427,6 @@ class Inliner(call: tpd.Tree, rhs: tpd.Tree)(implicit ctx: Context) {
414427
*/
415428
private object InlineTyper extends ReTyper {
416429

417-
var retainedInlineables = Set[Symbol]()
418-
419430
override def ensureAccessible(tpe: Type, superAccess: Boolean, pos: Position)(implicit ctx: Context): Type = {
420431
tpe match {
421432
case tpe @ TypeRef(pre, _) if !tpe.symbol.isAccessibleFrom(pre, superAccess) =>
@@ -455,13 +466,103 @@ class Inliner(call: tpd.Tree, rhs: tpd.Tree)(implicit ctx: Context) {
455466
}
456467
}
457468

458-
override def typedApply(tree: untpd.Apply, pt: Type)(implicit ctx: Context) =
459-
tree.asInstanceOf[tpd.Tree] match {
460-
case Apply(Select(InlineableArg(closure(_, fn, _)), nme.apply), args) =>
461-
inlining.println(i"reducing $tree with closure $fn")
462-
typed(fn.appliedToArgs(args), pt)
463-
case _ =>
464-
super.typedApply(tree, pt)
469+
override def typedApply(tree: untpd.Apply, pt: Type)(implicit ctx: Context) = {
470+
471+
def betaReduce(tree: Tree) = tree match {
472+
case Apply(Select(cl @ closureDef(ddef), nme.apply), args) =>
473+
ddef.tpe.widen match {
474+
case mt: MethodType if ddef.vparamss.head.length == args.length =>
475+
val bindingsBuf = new mutable.ListBuffer[ValOrDefDef]
476+
val argSyms = (mt.paramNames, mt.paramInfos, args).zipped.map { (name, paramtp, arg) =>
477+
arg.tpe.dealias match {
478+
case ref @ TermRef(NoPrefix, _) => ref.symbol
479+
case _ => paramBindingDef(name, paramtp, arg, bindingsBuf).symbol
480+
}
481+
}
482+
val expander = new TreeTypeMap(
483+
oldOwners = ddef.symbol :: Nil,
484+
newOwners = ctx.owner :: Nil,
485+
substFrom = ddef.vparamss.head.map(_.symbol),
486+
substTo = argSyms)
487+
Block(bindingsBuf.toList, expander.transform(ddef.rhs))
488+
case _ => tree
489+
}
490+
case _ => tree
465491
}
492+
493+
betaReduce(super.typedApply(tree, pt))
494+
}
495+
}
496+
497+
/** Drop any side-effect-free bindings that are unused in expansion or other reachable bindings.
498+
* Inline def bindings that are used only once.
499+
*/
500+
def dropUnusedDefs(bindings: List[ValOrDefDef], tree: Tree)(implicit ctx: Context): (List[ValOrDefDef], Tree) = {
501+
val refCount = newMutableSymbolMap[Int]
502+
val bindingOfSym = newMutableSymbolMap[ValOrDefDef]
503+
def isInlineable(binding: ValOrDefDef) = binding match {
504+
case DefDef(_, Nil, Nil, _, _) => true
505+
case vdef @ ValDef(_, _, _) => isPureExpr(vdef.rhs)
506+
case _ => false
507+
}
508+
for (binding <- bindings if isInlineable(binding)) {
509+
refCount(binding.symbol) = 0
510+
bindingOfSym(binding.symbol) = binding
511+
}
512+
val countRefs = new TreeTraverser {
513+
override def traverse(t: Tree)(implicit ctx: Context) = {
514+
t match {
515+
case t: RefTree =>
516+
refCount.get(t.symbol) match {
517+
case Some(x) => refCount(t.symbol) = x + 1
518+
case none =>
519+
}
520+
case _: New | _: TypeTree =>
521+
//println(i"refcount ${t.tpe}")
522+
t.tpe.foreachPart {
523+
case ref: TermRef =>
524+
refCount.get(ref.symbol) match {
525+
case Some(x) => refCount(ref.symbol) = x + 2
526+
case none =>
527+
}
528+
case _ =>
529+
}
530+
case _ =>
531+
}
532+
traverseChildren(t)
533+
}
534+
}
535+
countRefs.traverse(tree)
536+
for (binding <- bindings) countRefs.traverse(binding.rhs)
537+
val inlineBindings = new TreeMap {
538+
override def transform(t: Tree)(implicit ctx: Context) =
539+
super.transform {
540+
t match {
541+
case t: RefTree =>
542+
val sym = t.symbol
543+
refCount.get(sym) match {
544+
case Some(1) if sym.is(Method) =>
545+
bindingOfSym(sym).rhs.changeOwner(sym, ctx.owner)
546+
case none => t
547+
}
548+
case _ => t
549+
}
550+
}
551+
}
552+
def retain(binding: ValOrDefDef) = refCount.get(binding.symbol) match {
553+
case Some(x) => x > 1 || x == 1 && !binding.symbol.is(Method)
554+
case none => true
555+
}
556+
val retained = bindings.filterConserve(retain)
557+
if (retained `eq` bindings) {
558+
//println(i"DONE\n${bindings}%\n% ;;;\n ${tree}")
559+
(bindings, tree)
560+
}
561+
else {
562+
val expanded = inlineBindings.transform(tree)
563+
//println(i"ref counts: ${refCount.toMap map { case (sym, count) => i"$sym -> $count" }}")
564+
//println(i"""MAPPING\n${bindings}%\n% ;;;\n ${tree} \n------->\n${retained}%\n%;;;\n ${expanded} """)
565+
dropUnusedDefs(retained, expanded)
566+
}
466567
}
467568
}

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ class Typer extends Namer
417417

418418
/** Check that a stable identifier pattern is indeed stable (SLS 8.1.5)
419419
*/
420-
private def checkStableIdentPattern(tree: Tree, pt: Type)(implicit ctx: Context): Tree = {
420+
private def checkStableIdentPattern(tree: Tree, pt: Type)(implicit ctx: Context): tree.type = {
421421
if (ctx.mode.is(Mode.Pattern) &&
422422
!tree.isType &&
423423
!pt.isInstanceOf[ApplyingProto] &&
@@ -915,8 +915,7 @@ class Typer extends Namer
915915
else cpy.ValDef(param)(
916916
tpt = untpd.TypeTree(
917917
inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false)))
918-
val inlineable = pt.hasAnnotation(defn.InlineParamAnnot)
919-
desugar.makeClosure(inferredParams, fnBody, resultTpt, inlineable)
918+
desugar.makeClosure(inferredParams, fnBody, resultTpt)
920919
}
921920
typed(desugared, pt)
922921
}

tests/run/i4431/quoted_1.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import scala.quoted._
22

33
object Macros {
4-
inline def h(inline f: Int => String): String = ~ '(f(42))
4+
inline def h(f: => Int => String): String = ~ '(f(42))
55
}

tests/run/inlineByName.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
object Test {
2+
3+
class Range(from: Int, end: Int) {
4+
inline def foreach(op: => Int => Unit): Unit = {
5+
var i = from
6+
while (i < end) {
7+
op(i)
8+
i += 1
9+
}
10+
}
11+
}
12+
inline def twice(op: => Int => Unit): Unit = {
13+
op(1)
14+
op(2)
15+
}
16+
inline def thrice(op: => Unit): Unit = {
17+
op
18+
op
19+
op
20+
}
21+
22+
def main(args: Array[String]) = {
23+
var j = 0
24+
new Range(1, 10).foreach(j += _)
25+
assert(j == 45, j)
26+
twice { x => j = j - x }
27+
thrice { j = j + 1 }
28+
val f = new Range(1, 10).foreach
29+
f(j -= _)
30+
assert(j == 0, j)
31+
new Range(1, 10).foreach { i1 =>
32+
new Range(2, 11).foreach { i2 =>
33+
j += i1 * i2
34+
}
35+
}
36+
}
37+
}

tests/run/inlineForeach.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ object Test {
33
class Range(from: Int, end: Int) {
44

55
inline
6-
def foreach(inline op: Int => Unit): Unit = {
6+
def foreach(op: => Int => Unit): Unit = {
77
var i = from
88
while (i < end) {
99
op(i)
@@ -36,7 +36,7 @@ object Test {
3636
}
3737

3838
implicit class intArrayOps(arr: Array[Int]) {
39-
inline def foreach(inline op: Int => Unit): Unit = {
39+
inline def foreach(op: => Int => Unit): Unit = {
4040
var i = 0
4141
while (i < arr.length) {
4242
op(arr(i))

tests/run/inlinedAssign.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
object Test {
22

3-
inline def swap[T](x: T, inline x_= : T => Unit, y: T, inline y_= : T => Unit) = {
3+
inline def swap[T](x: T, x_= : => T => Unit, y: T, y_= : => T => Unit) = {
44
x_=(y)
55
y_=(x)
66
}

tests/run/lst/Lst.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class Lst[+T](val elems: Any) extends AnyVal {
2525
def isEmpty = elems == null
2626
def nonEmpty = elems != null
2727

28-
inline def foreach(inline op: T => Unit): Unit = {
28+
inline def foreach(op: => T => Unit): Unit = {
2929
def sharedOp(x: T) = op(x)
3030
elems match {
3131
case null =>
@@ -39,7 +39,7 @@ class Lst[+T](val elems: Any) extends AnyVal {
3939
/** Like `foreach`, but completely inlines `op`, at the price of generating the code twice.
4040
* Should be used only of `op` is small
4141
*/
42-
inline def foreachInlined(inline op: T => Unit): Unit = elems match {
42+
inline def foreachInlined(op: => T => Unit): Unit = elems match {
4343
case null =>
4444
case elems: Arr => def elem(i: Int) = elems(i).asInstanceOf[T]
4545
var i = 0
@@ -60,7 +60,7 @@ class Lst[+T](val elems: Any) extends AnyVal {
6060
}
6161

6262
/** `f` is pulled out, not duplicated */
63-
inline def map[U](inline f: T => U): Lst[U] = {
63+
inline def map[U](f: => T => U): Lst[U] = {
6464
def op(x: T) = f(x)
6565
elems match {
6666
case null => Empty
@@ -144,7 +144,7 @@ class Lst[+T](val elems: Any) extends AnyVal {
144144
}
145145
def filterNot(p: T => Boolean): Lst[T] = filter(!p(_))
146146

147-
inline def exists(inline p: T => Boolean): Boolean = {
147+
inline def exists(p: => T => Boolean): Boolean = {
148148
def op(x: T) = p(x)
149149
elems match {
150150
case null => false
@@ -157,7 +157,7 @@ class Lst[+T](val elems: Any) extends AnyVal {
157157
}
158158
}
159159

160-
inline def forall(inline p: T => Boolean): Boolean = {
160+
inline def forall(p: => T => Boolean): Boolean = {
161161
def op(x: T) = p(x)
162162
elems match {
163163
case null => true
@@ -180,7 +180,7 @@ class Lst[+T](val elems: Any) extends AnyVal {
180180
elem == x
181181
}
182182

183-
inline def foldLeft[U](z: U)(inline f: (U, T) => U) = {
183+
inline def foldLeft[U](z: U)(f: => (U, T) => U) = {
184184
def op(x: U, y: T) = f(x, y)
185185
elems match {
186186
case null => z
@@ -194,7 +194,7 @@ class Lst[+T](val elems: Any) extends AnyVal {
194194
}
195195
}
196196

197-
inline def /: [U](z: U)(inline op: (U, T) => U) = foldLeft(z)(op)
197+
inline def /: [U](z: U)(op: => (U, T) => U) = foldLeft(z)(op)
198198

199199
def reduceLeft[U >: T](op: (U, U) => U) = elems match {
200200
case elems: Arr => def elem(i: Int) = elems(i).asInstanceOf[T]

0 commit comments

Comments
 (0)