Skip to content

Beta-reduce directly applied PolymorphicFunction #16623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 12, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 0 additions & 40 deletions compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName}
import config.Printers.inlining
import util.SimpleIdentityMap

import dotty.tools.dotc.transform.BetaReduce

import collection.mutable

/** A utility class offering methods for rewriting inlined code */
Expand Down Expand Up @@ -150,44 +148,6 @@ class InlineReducer(inliner: Inliner)(using Context):
binding1.withSpan(call.span)
}

/** Rewrite an application
*
* ((x1, ..., xn) => b)(e1, ..., en)
*
* to
*
* val/def x1 = e1; ...; val/def xn = en; b
*
* where `def` is used for call-by-name parameters. However, we shortcut any NoPrefix
* refs among the ei's directly without creating an intermediate binding.
*
* This variant of beta-reduction preserves the integrity of `Inlined` tree nodes.
*/
def betaReduce(tree: Tree)(using Context): Tree = tree match {
case Apply(Select(cl, nme.apply), args) if defn.isFunctionType(cl.tpe) =>
val bindingsBuf = new mutable.ListBuffer[ValDef]
def recur(cl: Tree): Option[Tree] = cl match
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
ddef.tpe.widen match
case mt: MethodType if ddef.paramss.head.length == args.length =>
Some(BetaReduce.reduceApplication(ddef, args, bindingsBuf))
case _ => None
case Block(stats, expr) if stats.forall(isPureBinding) =>
recur(expr).map(cpy.Block(cl)(stats, _))
case Inlined(call, bindings, expr) if bindings.forall(isPureBinding) =>
recur(expr).map(cpy.Inlined(cl)(call, bindings, _))
case Typed(expr, tpt) =>
recur(expr)
case _ => None
recur(cl) match
case Some(reduced) =>
seq(bindingsBuf.result(), reduced).withSpan(tree.span)
case None =>
tree
case _ =>
tree
}

/** The result type of reducing a match. It consists optionally of a list of bindings
* for the pattern-bound variables and the RHS of the selected case.
* Returns `None` if no case was selected.
Expand Down
7 changes: 4 additions & 3 deletions compiler/src/dotty/tools/dotc/inlines/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import collection.mutable
import reporting.trace
import util.Spans.Span
import dotty.tools.dotc.transform.Splicer
import dotty.tools.dotc.transform.BetaReduce
import quoted.QuoteUtils
import scala.annotation.constructorOnly

Expand Down Expand Up @@ -811,7 +812,7 @@ class Inliner(val call: tpd.Tree)(using Context):
case Quoted(Spliced(inner)) => inner
case _ => tree
val locked = ctx.typerState.ownedVars
val res = cancelQuotes(constToLiteral(betaReduce(super.typedApply(tree, pt)))) match {
val res = cancelQuotes(constToLiteral(BetaReduce(super.typedApply(tree, pt)))) match {
case res: Apply if res.symbol == defn.QuotedRuntime_exprSplice
&& StagingContext.level == 0
&& !hasInliningErrors =>
Expand All @@ -824,7 +825,7 @@ class Inliner(val call: tpd.Tree)(using Context):

override def typedTypeApply(tree: untpd.TypeApply, pt: Type)(using Context): Tree =
val locked = ctx.typerState.ownedVars
val tree1 = inlineIfNeeded(constToLiteral(betaReduce(super.typedTypeApply(tree, pt))), pt, locked)
val tree1 = inlineIfNeeded(constToLiteral(BetaReduce(super.typedTypeApply(tree, pt))), pt, locked)
if tree1.symbol.isQuote then
ctx.compilationUnit.needsStaging = true
tree1
Expand Down Expand Up @@ -1005,7 +1006,7 @@ class Inliner(val call: tpd.Tree)(using Context):
super.transform(t1)
case t: Apply =>
val t1 = super.transform(t)
if (t1 `eq` t) t else reducer.betaReduce(t1)
if (t1 `eq` t) t else BetaReduce(t1)
case Block(Nil, expr) =>
super.transform(expr)
case _ =>
Expand Down
114 changes: 75 additions & 39 deletions compiler/src/dotty/tools/dotc/transform/BetaReduce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ import scala.collection.mutable.ListBuffer

/** Rewrite an application
*
* (((x1, ..., xn) => b): T)(y1, ..., yn)
* (([X1, ..., Xm] => (x1, ..., xn) => b): T)[T1, ..., Tm](y1, ..., yn)
*
* where
*
* - all yi are pure references without a prefix
* - the closure can also be contextual or erased, but cannot be a SAM type
* _ the type ascription ...: T is optional
* - the type parameters Xi and type arguments Ti are optional
* - the type ascription ...: T is optional
*
* to
*
Expand All @@ -38,51 +39,86 @@ class BetaReduce extends MiniPhase:

override def description: String = BetaReduce.description

override def transformApply(app: Apply)(using Context): Tree = app.fun match
case Select(fn, nme.apply) if defn.isFunctionType(fn.tpe) =>
val app1 = BetaReduce(app, fn, app.args)
if app1 ne app then report.log(i"beta reduce $app -> $app1")
app1
case _ =>
app

override def transformApply(app: Apply)(using Context): Tree =
val app1 = BetaReduce(app)
if app1 ne app then report.log(i"beta reduce $app -> $app1")
app1

object BetaReduce:
import ast.tpd._

val name: String = "betaReduce"
val description: String = "reduce closure applications"

/** Beta-reduces a call to `fn` with arguments `argSyms` or returns `tree` */
def apply(original: Tree, fn: Tree, args: List[Tree])(using Context): Tree =
fn match
case Typed(expr, _) =>
BetaReduce(original, expr, args)
case Block((anonFun: DefDef) :: Nil, closure: Closure) =>
BetaReduce(anonFun, args)
case Block(stats, expr) =>
val tree = BetaReduce(original, expr, args)
if tree eq original then original
else cpy.Block(fn)(stats, tree)
case Inlined(call, bindings, expr) =>
val tree = BetaReduce(original, expr, args)
if tree eq original then original
else cpy.Inlined(fn)(call, bindings, tree)
/** Rewrite an application
*
* ((x1, ..., xn) => b)(e1, ..., en)
*
* to
*
* val/def x1 = e1; ...; val/def xn = en; b
*
* where `def` is used for call-by-name parameters. However, we shortcut any NoPrefix
* refs among the ei's directly without creating an intermediate binding.
*
* Similarly, rewrites type applications
*
* ([X1, ..., Xm] => (x1, ..., xn) => b).apply[T1, .., Tm](e1, ..., en)
*
* to
*
* type X1 = T1; ...; type Xm = Tm;val/def x1 = e1; ...; val/def xn = en; b
*
* This beta-reduction preserves the integrity of `Inlined` tree nodes.
*/
def apply(tree: Tree)(using Context): Tree =
val bindingsBuf = new ListBuffer[DefTree]
def recur(fn: Tree, argss: List[List[Tree]]): Option[Tree] = fn match
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
Some(reduceApplication(ddef, argss, bindingsBuf))
case Block((TypeDef(_, template: Template)) :: Nil, Typed(Apply(Select(New(_), _), _), _)) if template.constr.rhs.isEmpty =>
template.body match
case (ddef: DefDef) :: Nil => Some(reduceApplication(ddef, argss, bindingsBuf))
case _ => None
case Block(stats, expr) if stats.forall(isPureBinding) =>
recur(expr, argss).map(cpy.Block(fn)(stats, _))
case Inlined(call, bindings, expr) if bindings.forall(isPureBinding) =>
recur(expr, argss).map(cpy.Inlined(fn)(call, bindings, _))
case Typed(expr, tpt) =>
recur(expr, argss)
case _ => None
tree match
case Apply(Select(fn, nme.apply), args) if defn.isFunctionType(fn.tpe) =>
recur(fn, List(args)) match
case Some(reduced) =>
seq(bindingsBuf.result(), reduced).withSpan(tree.span)
case None =>
tree
case Apply(TypeApply(Select(fn, nme.apply), targs), args) if fn.tpe.typeSymbol eq dotc.core.Symbols.defn.PolyFunctionClass =>
recur(fn, List(targs, args)) match
case Some(reduced) =>
seq(bindingsBuf.result(), reduced).withSpan(tree.span)
case None =>
tree
case _ =>
original
end apply

/** Beta-reduces a call to `ddef` with arguments `args` */
def apply(ddef: DefDef, args: List[Tree])(using Context) =
val bindings = new ListBuffer[ValDef]()
val expansion1 = reduceApplication(ddef, args, bindings)
val bindings1 = bindings.result()
seq(bindings1, expansion1)
tree

/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */
def reduceApplication(ddef: DefDef, args: List[Tree], bindings: ListBuffer[ValDef])(using Context): Tree =
val vparams = ddef.termParamss.iterator.flatten.toList
assert(args.hasSameLengthAs(vparams))
def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])(using Context): Tree =
val (targs, args) = argss.flatten.partition(_.isType)
val tparams = ddef.leadingTypeParams
val vparams = ddef.termParamss.flatten

val targSyms =
for (targ, tparam) <- targs.zip(tparams) yield
targ.tpe.dealias match
case ref @ TypeRef(NoPrefix, _) =>
ref.symbol
case _ =>
val binding = TypeDef(newSymbol(ctx.owner, tparam.name, EmptyFlags, targ.tpe, coord = targ.span)).withSpan(targ.span)
bindings += binding
binding.symbol

val argSyms =
for (arg, param) <- args.zip(vparams) yield
arg.tpe.dealias match
Expand All @@ -99,8 +135,8 @@ object BetaReduce:
val expansion = TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
substFrom = vparams.map(_.symbol),
substTo = argSyms
substFrom = (tparams ::: vparams).map(_.symbol),
substTo = targSyms ::: argSyms
).transform(ddef.rhs)

val expansion1 = new TreeMap {
Expand Down
13 changes: 10 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import Symbols._, Contexts._, Types._, Decorators._
import NameOps._
import Names._

import scala.collection.mutable.ListBuffer

/** Rewrite an application
*
* {new { def unapply(x0: X0)(x1: X1,..., xn: Xn) = b }}.unapply(y0)(y1, ..., yn)
Expand Down Expand Up @@ -38,7 +40,7 @@ class InlinePatterns extends MiniPhase:
if app.symbol.name.isUnapplyName && !app.tpe.isInstanceOf[MethodicType] then
app match
case App(Select(fn, name), argss) =>
val app1 = betaReduce(app, fn, name, argss.flatten)
val app1 = betaReduce(app, fn, name, argss)
if app1 ne app then report.log(i"beta reduce $app -> $app1")
app1
case _ =>
Expand All @@ -51,11 +53,16 @@ class InlinePatterns extends MiniPhase:
case Apply(App(fn, argss), args) => (fn, argss :+ args)
case _ => (app, Nil)

private def betaReduce(tree: Apply, fn: Tree, name: Name, args: List[Tree])(using Context): Tree =
// TODO merge with BetaReduce.scala
private def betaReduce(tree: Apply, fn: Tree, name: Name, argss: List[List[Tree]])(using Context): Tree =
fn match
case Block(TypeDef(_, template: Template) :: Nil, Apply(Select(New(_),_), Nil)) if template.constr.rhs.isEmpty =>
template.body match
case List(ddef @ DefDef(`name`, _, _, _)) => BetaReduce(ddef, args)
case List(ddef @ DefDef(`name`, _, _, _)) =>
val bindings = new ListBuffer[DefTree]()
val expansion1 = BetaReduce.reduceApplication(ddef, argss, bindings)
val bindings1 = bindings.result()
seq(bindings1, expansion1)
case _ => tree
case _ => tree

Expand Down
5 changes: 4 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,10 @@ object PickleQuotes {
}
val Block(List(ddef: DefDef), _) = splice: @unchecked
// TODO: beta reduce inner closure? Or wait until BetaReduce phase?
BetaReduce(ddef, spliceArgs).select(nme.apply).appliedTo(args(2).asInstance(quotesType))
BetaReduce(
splice
.select(nme.apply).appliedToArgs(spliceArgs))
.select(nme.apply).appliedTo(args(2).asInstance(quotesType))
}
CaseDef(Literal(Constant(idx)), EmptyTree, rhs)
}
Expand Down
9 changes: 4 additions & 5 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -362,16 +362,15 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
object Term extends TermModule:
def betaReduce(tree: Term): Option[Term] =
tree match
case app @ tpd.Apply(tpd.Select(fn, nme.apply), args) if dotc.core.Symbols.defn.isFunctionType(fn.tpe) =>
val app1 = dotc.transform.BetaReduce(app, fn, args)
if app1 eq app then None
else Some(app1.withSpan(tree.span))
case tpd.Block(Nil, expr) =>
for e <- betaReduce(expr) yield tpd.cpy.Block(tree)(Nil, e)
case tpd.Inlined(_, Nil, expr) =>
betaReduce(expr)
case _ =>
None
val tree1 = dotc.transform.BetaReduce(tree)
if tree1 eq tree then None
else Some(tree1.withSpan(tree.span))

end Term

given TermMethods: TermMethods with
Expand Down
26 changes: 26 additions & 0 deletions compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,32 @@ class InlineBytecodeTests extends DottyBytecodeTest {
}
}

@Test def beta_reduce_polymorphic_function = {
val source = """class Test:
| def test =
| ([Z] => (arg: Z) => { val a: Z = arg; a }).apply[Int](2)
""".stripMargin

checkBCode(source) { dir =>
val clsIn = dir.lookupName("Test.class", directory = false).input
val clsNode = loadClassNode(clsIn)

val fun = getMethod(clsNode, "test")
val instructions = instructionsFromMethod(fun)
val expected =
List(
Op(ICONST_2),
VarOp(ISTORE, 1),
VarOp(ILOAD, 1),
Op(IRETURN)
)

assert(instructions == expected,
"`i was not properly beta-reduced in `test`\n" + diffInstructions(instructions, expected))

}
}

@Test def i9456 = {
val source = """class Foo {
| def test: Int = inline2(inline1(2.+))
Expand Down
5 changes: 5 additions & 0 deletions tests/run-macros/i15968.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
type Z = java.lang.String
"foo".toString()
}
"foo".toString()
15 changes: 15 additions & 0 deletions tests/run-macros/i15968/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import scala.quoted.*

inline def macroPolyFun[A](inline arg: A, inline f: [Z] => Z => String): String =
${ macroPolyFunImpl[A]('arg, 'f) }

private def macroPolyFunImpl[A: Type](arg: Expr[A], f: Expr[[Z] => Z => String])(using Quotes): Expr[String] =
Expr(Expr.betaReduce('{ $f($arg) }).show)


inline def macroFun[A](inline arg: A, inline f: A => String): String =
${ macroFunImpl[A]('arg, 'f) }

private def macroFunImpl[A: Type](arg: Expr[A], f: Expr[A => String])(using Quotes): Expr[String] =
Expr(Expr.betaReduce('{ $f($arg) }).show)

3 changes: 3 additions & 0 deletions tests/run-macros/i15968/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
@main def Test: Unit =
println(macroPolyFun("foo", [Z] => (arg: Z) => arg.toString))
println(macroFun("foo", arg => arg.toString))
7 changes: 7 additions & 0 deletions tests/run-macros/inline-beta-reduce-polyfunction.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
type X = Int
{
println(1)
1
}
}
5 changes: 5 additions & 0 deletions tests/run-macros/inline-beta-reduce-polyfunction.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
transparent inline def foo(inline f: [X] => X => X): Int = f[Int](1)

@main def Test: Unit =
val code = compiletime.codeOf(foo([X] => (x: X) => { println(x); x }))
println(code)