Skip to content

Reuse beta reduction logic from BetaReduce #16390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 5 additions & 17 deletions compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName}
import config.Printers.inlining
import util.SimpleIdentityMap

import dotty.tools.dotc.transform.BetaReduce

import collection.mutable

/** A utility class offering methods for rewriting inlined code */
Expand Down Expand Up @@ -163,26 +165,12 @@ class InlineReducer(inliner: Inliner)(using Context):
*/
def betaReduce(tree: Tree)(using Context): Tree = tree match {
case Apply(Select(cl, nme.apply), args) if defn.isFunctionType(cl.tpe) =>
val bindingsBuf = new DefBuffer
val bindingsBuf = new mutable.ListBuffer[ValDef]
def recur(cl: Tree): Option[Tree] = cl match
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
ddef.tpe.widen match
case mt: MethodType if ddef.paramss.head.length == args.length =>
val argSyms = mt.paramNames.lazyZip(mt.paramInfos).lazyZip(args).map { (name, paramtp, arg) =>
arg.tpe.dealias match {
case ref @ TermRef(NoPrefix, _) => ref.symbol
case _ =>
paramBindingDef(name, paramtp, arg, bindingsBuf)(
using ctx.withSource(cl.source)
).symbol
}
}
val expander = new TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
substFrom = ddef.paramss.head.map(_.symbol),
substTo = argSyms)
Some(expander.transform(ddef.rhs))
Some(BetaReduce.reduceApplication(ddef, args, bindingsBuf))
case _ => None
case Block(stats, expr) if stats.forall(isPureBinding) =>
recur(expr).map(cpy.Block(cl)(stats, _))
Expand All @@ -193,7 +181,7 @@ class InlineReducer(inliner: Inliner)(using Context):
case _ => None
recur(cl) match
case Some(reduced) =>
Block(bindingsBuf.toList, reduced).withSpan(tree.span)
seq(bindingsBuf.result(), reduced).withSpan(tree.span)
case None =>
tree
case _ =>
Expand Down
20 changes: 13 additions & 7 deletions compiler/src/dotty/tools/dotc/transform/BetaReduce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import Symbols._, Contexts._, Types._, Decorators._
import StdNames.nme
import ast.TreeTypeMap

import scala.collection.mutable.ListBuffer

/** Rewrite an application
*
* (((x1, ..., xn) => b): T)(y1, ..., yn)
Expand Down Expand Up @@ -70,9 +72,15 @@ object BetaReduce:
original
end apply

/** Beta-reduces a call to `ddef` with arguments `argSyms` */
/** Beta-reduces a call to `ddef` with arguments `args` */
def apply(ddef: DefDef, args: List[Tree])(using Context) =
val bindings = List.newBuilder[ValDef]
val bindings = new ListBuffer[ValDef]()
val expansion1 = reduceApplication(ddef, args, bindings)
val bindings1 = bindings.result()
seq(bindings1, expansion1)

/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */
def reduceApplication(ddef: DefDef, args: List[Tree], bindings: ListBuffer[ValDef])(using Context): Tree =
val vparams = ddef.termParamss.iterator.flatten.toList
assert(args.hasSameLengthAs(vparams))
val argSyms =
Expand All @@ -84,7 +92,8 @@ object BetaReduce:
val flags = Synthetic | (param.symbol.flags & Erased)
val tpe = if arg.tpe.dealias.isInstanceOf[ConstantType] then arg.tpe.dealias else arg.tpe.widen
val binding = ValDef(newSymbol(ctx.owner, param.name, flags, tpe, coord = arg.span), arg).withSpan(arg.span)
bindings += binding
if !(tpe.isInstanceOf[ConstantType] && isPureExpr(arg)) then
bindings += binding
binding.symbol

val expansion = TreeTypeMap(
Expand All @@ -99,8 +108,5 @@ object BetaReduce:
case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const)
case _ => super.transform(tree)
}.transform(expansion)
val bindings1 =
bindings.result().filterNot(vdef => vdef.tpt.tpe.isInstanceOf[ConstantType] && isPureExpr(vdef.rhs))

seq(bindings1, expansion1)
end apply
expansion1
10 changes: 1 addition & 9 deletions compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -600,15 +600,7 @@ class InlineBytecodeTests extends DottyBytecodeTest {
val instructions = instructionsFromMethod(fun)
val expected = // TODO room for constant folding
List(
Op(ICONST_2),
VarOp(ISTORE, 1),
Op(ICONST_1),
VarOp(ISTORE, 2),
Op(ICONST_2),
VarOp(ILOAD, 2),
Op(IADD),
Op(ICONST_3),
Op(IADD),
IntOp(BIPUSH, 6),
Op(IRETURN),
)
assert(instructions == expected,
Expand Down
23 changes: 23 additions & 0 deletions tests/run/i16390.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
inline def cfor(inline body: Int => Unit): Unit =
var index = 0
while index < 3 do
body(index)
index = index + 1

@main def Test =
assert(test1() == test2(), (test1(), test2()))

def test1() =
val b = collection.mutable.ArrayBuffer.empty[() => Int]
cfor { x =>
b += (() => x)
}
b.map(_.apply()).toList

def test2() =
val b = collection.mutable.ArrayBuffer.empty[() => Int]
var index = 0
while index < 3 do
((x: Int) => b += (() => x)).apply(index)
index = index + 1
b.map(_.apply()).toList