Skip to content

Commit c5c5aa6

Browse files
authored
Beta-reduce directly applied PolymorphicFunction (#16623)
Beta-reduce directly applied PolymorphicFunction such as ```scala ([Z] => (arg: Z) => { def a: Z = arg; a }).apply[Int](2) ``` into ```scala type Z = Int val arg = 2 def a: Z = arg a ``` Apply this beta reduction in the `BetaReduce` phase and `Expr.betaReduce`. Also, refactor the beta-reduce logic to avoid code duplication. Fixes #15968
2 parents a2c89fb + db2d3eb commit c5c5aa6

File tree

12 files changed

+191
-91
lines changed

12 files changed

+191
-91
lines changed

compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName}
1212
import config.Printers.inlining
1313
import util.SimpleIdentityMap
1414

15-
import dotty.tools.dotc.transform.BetaReduce
16-
1715
import collection.mutable
1816

1917
/** A utility class offering methods for rewriting inlined code */
@@ -150,44 +148,6 @@ class InlineReducer(inliner: Inliner)(using Context):
150148
binding1.withSpan(call.span)
151149
}
152150

153-
/** Rewrite an application
154-
*
155-
* ((x1, ..., xn) => b)(e1, ..., en)
156-
*
157-
* to
158-
*
159-
* val/def x1 = e1; ...; val/def xn = en; b
160-
*
161-
* where `def` is used for call-by-name parameters. However, we shortcut any NoPrefix
162-
* refs among the ei's directly without creating an intermediate binding.
163-
*
164-
* This variant of beta-reduction preserves the integrity of `Inlined` tree nodes.
165-
*/
166-
def betaReduce(tree: Tree)(using Context): Tree = tree match {
167-
case Apply(Select(cl, nme.apply), args) if defn.isFunctionType(cl.tpe) =>
168-
val bindingsBuf = new mutable.ListBuffer[ValDef]
169-
def recur(cl: Tree): Option[Tree] = cl match
170-
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
171-
ddef.tpe.widen match
172-
case mt: MethodType if ddef.paramss.head.length == args.length =>
173-
Some(BetaReduce.reduceApplication(ddef, args, bindingsBuf))
174-
case _ => None
175-
case Block(stats, expr) if stats.forall(isPureBinding) =>
176-
recur(expr).map(cpy.Block(cl)(stats, _))
177-
case Inlined(call, bindings, expr) if bindings.forall(isPureBinding) =>
178-
recur(expr).map(cpy.Inlined(cl)(call, bindings, _))
179-
case Typed(expr, tpt) =>
180-
recur(expr)
181-
case _ => None
182-
recur(cl) match
183-
case Some(reduced) =>
184-
seq(bindingsBuf.result(), reduced).withSpan(tree.span)
185-
case None =>
186-
tree
187-
case _ =>
188-
tree
189-
}
190-
191151
/** The result type of reducing a match. It consists optionally of a list of bindings
192152
* for the pattern-bound variables and the RHS of the selected case.
193153
* Returns `None` if no case was selected.

compiler/src/dotty/tools/dotc/inlines/Inliner.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import collection.mutable
2121
import reporting.trace
2222
import util.Spans.Span
2323
import dotty.tools.dotc.transform.Splicer
24+
import dotty.tools.dotc.transform.BetaReduce
2425
import quoted.QuoteUtils
2526
import scala.annotation.constructorOnly
2627

@@ -811,7 +812,7 @@ class Inliner(val call: tpd.Tree)(using Context):
811812
case Quoted(Spliced(inner)) => inner
812813
case _ => tree
813814
val locked = ctx.typerState.ownedVars
814-
val res = cancelQuotes(constToLiteral(betaReduce(super.typedApply(tree, pt)))) match {
815+
val res = cancelQuotes(constToLiteral(BetaReduce(super.typedApply(tree, pt)))) match {
815816
case res: Apply if res.symbol == defn.QuotedRuntime_exprSplice
816817
&& StagingContext.level == 0
817818
&& !hasInliningErrors =>
@@ -825,7 +826,7 @@ class Inliner(val call: tpd.Tree)(using Context):
825826

826827
override def typedTypeApply(tree: untpd.TypeApply, pt: Type)(using Context): Tree =
827828
val locked = ctx.typerState.ownedVars
828-
val tree1 = inlineIfNeeded(constToLiteral(betaReduce(super.typedTypeApply(tree, pt))), pt, locked)
829+
val tree1 = inlineIfNeeded(constToLiteral(BetaReduce(super.typedTypeApply(tree, pt))), pt, locked)
829830
if tree1.symbol.isQuote then
830831
ctx.compilationUnit.needsStaging = true
831832
tree1
@@ -1006,7 +1007,7 @@ class Inliner(val call: tpd.Tree)(using Context):
10061007
super.transform(t1)
10071008
case t: Apply =>
10081009
val t1 = super.transform(t)
1009-
if (t1 `eq` t) t else reducer.betaReduce(t1)
1010+
if (t1 `eq` t) t else BetaReduce(t1)
10101011
case Block(Nil, expr) =>
10111012
super.transform(expr)
10121013
case _ =>

compiler/src/dotty/tools/dotc/transform/BetaReduce.scala

Lines changed: 77 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@ import scala.collection.mutable.ListBuffer
1313

1414
/** Rewrite an application
1515
*
16-
* (((x1, ..., xn) => b): T)(y1, ..., yn)
16+
* (([X1, ..., Xm] => (x1, ..., xn) => b): T)[T1, ..., Tm](y1, ..., yn)
1717
*
1818
* where
1919
*
2020
* - all yi are pure references without a prefix
2121
* - the closure can also be contextual or erased, but cannot be a SAM type
22-
* _ the type ascription ...: T is optional
22+
* - the type parameters Xi and type arguments Ti are optional
23+
* - the type ascription ...: T is optional
2324
*
2425
* to
2526
*
@@ -38,51 +39,88 @@ class BetaReduce extends MiniPhase:
3839

3940
override def description: String = BetaReduce.description
4041

41-
override def transformApply(app: Apply)(using Context): Tree = app.fun match
42-
case Select(fn, nme.apply) if defn.isFunctionType(fn.tpe) =>
43-
val app1 = BetaReduce(app, fn, app.args)
44-
if app1 ne app then report.log(i"beta reduce $app -> $app1")
45-
app1
46-
case _ =>
47-
app
48-
42+
override def transformApply(app: Apply)(using Context): Tree =
43+
val app1 = BetaReduce(app)
44+
if app1 ne app then report.log(i"beta reduce $app -> $app1")
45+
app1
4946

5047
object BetaReduce:
5148
import ast.tpd._
5249

5350
val name: String = "betaReduce"
5451
val description: String = "reduce closure applications"
5552

56-
/** Beta-reduces a call to `fn` with arguments `argSyms` or returns `tree` */
57-
def apply(original: Tree, fn: Tree, args: List[Tree])(using Context): Tree =
58-
fn match
59-
case Typed(expr, _) =>
60-
BetaReduce(original, expr, args)
61-
case Block((anonFun: DefDef) :: Nil, closure: Closure) =>
62-
BetaReduce(anonFun, args)
63-
case Block(stats, expr) =>
64-
val tree = BetaReduce(original, expr, args)
65-
if tree eq original then original
66-
else cpy.Block(fn)(stats, tree)
67-
case Inlined(call, bindings, expr) =>
68-
val tree = BetaReduce(original, expr, args)
69-
if tree eq original then original
70-
else cpy.Inlined(fn)(call, bindings, tree)
53+
/** Rewrite an application
54+
*
55+
* ((x1, ..., xn) => b)(e1, ..., en)
56+
*
57+
* to
58+
*
59+
* val/def x1 = e1; ...; val/def xn = en; b
60+
*
61+
* where `def` is used for call-by-name parameters. However, we shortcut any NoPrefix
62+
* refs among the ei's directly without creating an intermediate binding.
63+
*
64+
* Similarly, rewrites type applications
65+
*
66+
* ([X1, ..., Xm] => (x1, ..., xn) => b).apply[T1, .., Tm](e1, ..., en)
67+
*
68+
* to
69+
*
70+
* type X1 = T1; ...; type Xm = Tm;val/def x1 = e1; ...; val/def xn = en; b
71+
*
72+
* This beta-reduction preserves the integrity of `Inlined` tree nodes.
73+
*/
74+
def apply(tree: Tree)(using Context): Tree =
75+
val bindingsBuf = new ListBuffer[DefTree]
76+
def recur(fn: Tree, argss: List[List[Tree]]): Option[Tree] = fn match
77+
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
78+
Some(reduceApplication(ddef, argss, bindingsBuf))
79+
case Block((TypeDef(_, template: Template)) :: Nil, Typed(Apply(Select(New(_), _), _), _)) if template.constr.rhs.isEmpty =>
80+
template.body match
81+
case (ddef: DefDef) :: Nil => Some(reduceApplication(ddef, argss, bindingsBuf))
82+
case _ => None
83+
case Block(stats, expr) if stats.forall(isPureBinding) =>
84+
recur(expr, argss).map(cpy.Block(fn)(stats, _))
85+
case Inlined(call, bindings, expr) if bindings.forall(isPureBinding) =>
86+
recur(expr, argss).map(cpy.Inlined(fn)(call, bindings, _))
87+
case Typed(expr, tpt) =>
88+
recur(expr, argss)
89+
case TypeApply(Select(expr, nme.asInstanceOfPM), List(tpt)) =>
90+
recur(expr, argss)
91+
case _ => None
92+
tree match
93+
case Apply(Select(fn, nme.apply), args) if defn.isFunctionType(fn.tpe) =>
94+
recur(fn, List(args)) match
95+
case Some(reduced) =>
96+
seq(bindingsBuf.result(), reduced).withSpan(tree.span)
97+
case None =>
98+
tree
99+
case Apply(TypeApply(Select(fn, nme.apply), targs), args) if fn.tpe.typeSymbol eq dotc.core.Symbols.defn.PolyFunctionClass =>
100+
recur(fn, List(targs, args)) match
101+
case Some(reduced) =>
102+
seq(bindingsBuf.result(), reduced).withSpan(tree.span)
103+
case None =>
104+
tree
71105
case _ =>
72-
original
73-
end apply
74-
75-
/** Beta-reduces a call to `ddef` with arguments `args` */
76-
def apply(ddef: DefDef, args: List[Tree])(using Context) =
77-
val bindings = new ListBuffer[ValDef]()
78-
val expansion1 = reduceApplication(ddef, args, bindings)
79-
val bindings1 = bindings.result()
80-
seq(bindings1, expansion1)
106+
tree
81107

82108
/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */
83-
def reduceApplication(ddef: DefDef, args: List[Tree], bindings: ListBuffer[ValDef])(using Context): Tree =
84-
val vparams = ddef.termParamss.iterator.flatten.toList
85-
assert(args.hasSameLengthAs(vparams))
109+
def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])(using Context): Tree =
110+
val (targs, args) = argss.flatten.partition(_.isType)
111+
val tparams = ddef.leadingTypeParams
112+
val vparams = ddef.termParamss.flatten
113+
114+
val targSyms =
115+
for (targ, tparam) <- targs.zip(tparams) yield
116+
targ.tpe.dealias match
117+
case ref @ TypeRef(NoPrefix, _) =>
118+
ref.symbol
119+
case _ =>
120+
val binding = TypeDef(newSymbol(ctx.owner, tparam.name, EmptyFlags, targ.tpe, coord = targ.span)).withSpan(targ.span)
121+
bindings += binding
122+
binding.symbol
123+
86124
val argSyms =
87125
for (arg, param) <- args.zip(vparams) yield
88126
arg.tpe.dealias match
@@ -99,8 +137,8 @@ object BetaReduce:
99137
val expansion = TreeTypeMap(
100138
oldOwners = ddef.symbol :: Nil,
101139
newOwners = ctx.owner :: Nil,
102-
substFrom = vparams.map(_.symbol),
103-
substTo = argSyms
140+
substFrom = (tparams ::: vparams).map(_.symbol),
141+
substTo = targSyms ::: argSyms
104142
).transform(ddef.rhs)
105143

106144
val expansion1 = new TreeMap {

compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import Symbols._, Contexts._, Types._, Decorators._
88
import NameOps._
99
import Names._
1010

11+
import scala.collection.mutable.ListBuffer
12+
1113
/** Rewrite an application
1214
*
1315
* {new { def unapply(x0: X0)(x1: X1,..., xn: Xn) = b }}.unapply(y0)(y1, ..., yn)
@@ -38,7 +40,7 @@ class InlinePatterns extends MiniPhase:
3840
if app.symbol.name.isUnapplyName && !app.tpe.isInstanceOf[MethodicType] then
3941
app match
4042
case App(Select(fn, name), argss) =>
41-
val app1 = betaReduce(app, fn, name, argss.flatten)
43+
val app1 = betaReduce(app, fn, name, argss)
4244
if app1 ne app then report.log(i"beta reduce $app -> $app1")
4345
app1
4446
case _ =>
@@ -51,11 +53,16 @@ class InlinePatterns extends MiniPhase:
5153
case Apply(App(fn, argss), args) => (fn, argss :+ args)
5254
case _ => (app, Nil)
5355

54-
private def betaReduce(tree: Apply, fn: Tree, name: Name, args: List[Tree])(using Context): Tree =
56+
// TODO merge with BetaReduce.scala
57+
private def betaReduce(tree: Apply, fn: Tree, name: Name, argss: List[List[Tree]])(using Context): Tree =
5558
fn match
5659
case Block(TypeDef(_, template: Template) :: Nil, Apply(Select(New(_),_), Nil)) if template.constr.rhs.isEmpty =>
5760
template.body match
58-
case List(ddef @ DefDef(`name`, _, _, _)) => BetaReduce(ddef, args)
61+
case List(ddef @ DefDef(`name`, _, _, _)) =>
62+
val bindings = new ListBuffer[DefTree]()
63+
val expansion1 = BetaReduce.reduceApplication(ddef, argss, bindings)
64+
val bindings1 = bindings.result()
65+
seq(bindings1, expansion1)
5966
case _ => tree
6067
case _ => tree
6168

compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,10 @@ object PickleQuotes {
322322
}
323323
val Block(List(ddef: DefDef), _) = splice: @unchecked
324324
// TODO: beta reduce inner closure? Or wait until BetaReduce phase?
325-
BetaReduce(ddef, spliceArgs).select(nme.apply).appliedTo(args(2).asInstance(quotesType))
325+
BetaReduce(
326+
splice
327+
.select(nme.apply).appliedToArgs(spliceArgs))
328+
.select(nme.apply).appliedTo(args(2).asInstance(quotesType))
326329
}
327330
CaseDef(Literal(Constant(idx)), EmptyTree, rhs)
328331
}

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -371,16 +371,15 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
371371
object Term extends TermModule:
372372
def betaReduce(tree: Term): Option[Term] =
373373
tree match
374-
case app @ tpd.Apply(tpd.Select(fn, nme.apply), args) if dotc.core.Symbols.defn.isFunctionType(fn.tpe) =>
375-
val app1 = dotc.transform.BetaReduce(app, fn, args)
376-
if app1 eq app then None
377-
else Some(app1.withSpan(tree.span))
378374
case tpd.Block(Nil, expr) =>
379375
for e <- betaReduce(expr) yield tpd.cpy.Block(tree)(Nil, e)
380376
case tpd.Inlined(_, Nil, expr) =>
381377
betaReduce(expr)
382378
case _ =>
383-
None
379+
val tree1 = dotc.transform.BetaReduce(tree)
380+
if tree1 eq tree then None
381+
else Some(tree1.withSpan(tree.span))
382+
384383
end Term
385384

386385
given TermMethods: TermMethods with

compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,63 @@ class InlineBytecodeTests extends DottyBytecodeTest {
578578
}
579579
}
580580

581+
@Test def beta_reduce_polymorphic_function = {
582+
val source = """class Test:
583+
| def test =
584+
| ([Z] => (arg: Z) => { val a: Z = arg; a }).apply[Int](2)
585+
""".stripMargin
586+
587+
checkBCode(source) { dir =>
588+
val clsIn = dir.lookupName("Test.class", directory = false).input
589+
val clsNode = loadClassNode(clsIn)
590+
591+
val fun = getMethod(clsNode, "test")
592+
val instructions = instructionsFromMethod(fun)
593+
val expected =
594+
List(
595+
Op(ICONST_2),
596+
VarOp(ISTORE, 1),
597+
VarOp(ILOAD, 1),
598+
Op(IRETURN)
599+
)
600+
601+
assert(instructions == expected,
602+
"`i was not properly beta-reduced in `test`\n" + diffInstructions(instructions, expected))
603+
604+
}
605+
}
606+
607+
@Test def beta_reduce_function_of_opaque_types = {
608+
val source = """object foo:
609+
| opaque type T = Int
610+
| inline def apply(inline op: T => T): T = op(2)
611+
|
612+
|class Test:
613+
| def test = foo { n => n }
614+
""".stripMargin
615+
616+
checkBCode(source) { dir =>
617+
val clsIn = dir.lookupName("Test.class", directory = false).input
618+
val clsNode = loadClassNode(clsIn)
619+
620+
val fun = getMethod(clsNode, "test")
621+
val instructions = instructionsFromMethod(fun)
622+
val expected =
623+
List(
624+
Field(GETSTATIC, "foo$", "MODULE$", "Lfoo$;"),
625+
VarOp(ASTORE, 1),
626+
VarOp(ALOAD, 1),
627+
VarOp(ASTORE, 2),
628+
Op(ICONST_2),
629+
Op(IRETURN),
630+
)
631+
632+
assert(instructions == expected,
633+
"`i was not properly beta-reduced in `test`\n" + diffInstructions(instructions, expected))
634+
635+
}
636+
}
637+
581638
@Test def i9456 = {
582639
val source = """class Foo {
583640
| def test: Int = inline2(inline1(2.+))

tests/run-macros/i15968.check

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
type Z = java.lang.String
3+
"foo".toString()
4+
}
5+
"foo".toString()

tests/run-macros/i15968/Macro_1.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import scala.quoted.*
2+
3+
inline def macroPolyFun[A](inline arg: A, inline f: [Z] => Z => String): String =
4+
${ macroPolyFunImpl[A]('arg, 'f) }
5+
6+
private def macroPolyFunImpl[A: Type](arg: Expr[A], f: Expr[[Z] => Z => String])(using Quotes): Expr[String] =
7+
Expr(Expr.betaReduce('{ $f($arg) }).show)
8+
9+
10+
inline def macroFun[A](inline arg: A, inline f: A => String): String =
11+
${ macroFunImpl[A]('arg, 'f) }
12+
13+
private def macroFunImpl[A: Type](arg: Expr[A], f: Expr[A => String])(using Quotes): Expr[String] =
14+
Expr(Expr.betaReduce('{ $f($arg) }).show)
15+

0 commit comments

Comments
 (0)