diff --git a/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala b/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala index 42e86b71eff8..547160340238 100644 --- a/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala +++ b/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala @@ -12,8 +12,6 @@ import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName} import config.Printers.inlining import util.SimpleIdentityMap -import dotty.tools.dotc.transform.BetaReduce - import collection.mutable /** A utility class offering methods for rewriting inlined code */ @@ -150,44 +148,6 @@ class InlineReducer(inliner: Inliner)(using Context): binding1.withSpan(call.span) } - /** Rewrite an application - * - * ((x1, ..., xn) => b)(e1, ..., en) - * - * to - * - * val/def x1 = e1; ...; val/def xn = en; b - * - * where `def` is used for call-by-name parameters. However, we shortcut any NoPrefix - * refs among the ei's directly without creating an intermediate binding. - * - * This variant of beta-reduction preserves the integrity of `Inlined` tree nodes. - */ - def betaReduce(tree: Tree)(using Context): Tree = tree match { - case Apply(Select(cl, nme.apply), args) if defn.isFunctionType(cl.tpe) => - val bindingsBuf = new mutable.ListBuffer[ValDef] - def recur(cl: Tree): Option[Tree] = cl match - case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol => - ddef.tpe.widen match - case mt: MethodType if ddef.paramss.head.length == args.length => - Some(BetaReduce.reduceApplication(ddef, args, bindingsBuf)) - case _ => None - case Block(stats, expr) if stats.forall(isPureBinding) => - recur(expr).map(cpy.Block(cl)(stats, _)) - case Inlined(call, bindings, expr) if bindings.forall(isPureBinding) => - recur(expr).map(cpy.Inlined(cl)(call, bindings, _)) - case Typed(expr, tpt) => - recur(expr) - case _ => None - recur(cl) match - case Some(reduced) => - seq(bindingsBuf.result(), reduced).withSpan(tree.span) - case None => - tree - case _ => - tree - } - /** The result type of reducing a match. It consists optionally of a list of bindings * for the pattern-bound variables and the RHS of the selected case. * Returns `None` if no case was selected. diff --git a/compiler/src/dotty/tools/dotc/inlines/Inliner.scala b/compiler/src/dotty/tools/dotc/inlines/Inliner.scala index c71648984664..40b0b383a4e3 100644 --- a/compiler/src/dotty/tools/dotc/inlines/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/inlines/Inliner.scala @@ -21,6 +21,7 @@ import collection.mutable import reporting.trace import util.Spans.Span import dotty.tools.dotc.transform.Splicer +import dotty.tools.dotc.transform.BetaReduce import quoted.QuoteUtils import scala.annotation.constructorOnly @@ -811,7 +812,7 @@ class Inliner(val call: tpd.Tree)(using Context): case Quoted(Spliced(inner)) => inner case _ => tree val locked = ctx.typerState.ownedVars - val res = cancelQuotes(constToLiteral(betaReduce(super.typedApply(tree, pt)))) match { + val res = cancelQuotes(constToLiteral(BetaReduce(super.typedApply(tree, pt)))) match { case res: Apply if res.symbol == defn.QuotedRuntime_exprSplice && StagingContext.level == 0 && !hasInliningErrors => @@ -824,7 +825,7 @@ class Inliner(val call: tpd.Tree)(using Context): override def typedTypeApply(tree: untpd.TypeApply, pt: Type)(using Context): Tree = val locked = ctx.typerState.ownedVars - val tree1 = inlineIfNeeded(constToLiteral(betaReduce(super.typedTypeApply(tree, pt))), pt, locked) + val tree1 = inlineIfNeeded(constToLiteral(BetaReduce(super.typedTypeApply(tree, pt))), pt, locked) if tree1.symbol.isQuote then ctx.compilationUnit.needsStaging = true tree1 @@ -1005,7 +1006,7 @@ class Inliner(val call: tpd.Tree)(using Context): super.transform(t1) case t: Apply => val t1 = super.transform(t) - if (t1 `eq` t) t else reducer.betaReduce(t1) + if (t1 `eq` t) t else BetaReduce(t1) case Block(Nil, expr) => super.transform(expr) case _ => diff --git a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala index 7ac3dc972ad1..97dc4697db6d 100644 --- a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala +++ b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala @@ -13,13 +13,14 @@ import scala.collection.mutable.ListBuffer /** Rewrite an application * - * (((x1, ..., xn) => b): T)(y1, ..., yn) + * (([X1, ..., Xm] => (x1, ..., xn) => b): T)[T1, ..., Tm](y1, ..., yn) * * where * * - all yi are pure references without a prefix * - the closure can also be contextual or erased, but cannot be a SAM type - * _ the type ascription ...: T is optional + * - the type parameters Xi and type arguments Ti are optional + * - the type ascription ...: T is optional * * to * @@ -38,14 +39,10 @@ class BetaReduce extends MiniPhase: override def description: String = BetaReduce.description - override def transformApply(app: Apply)(using Context): Tree = app.fun match - case Select(fn, nme.apply) if defn.isFunctionType(fn.tpe) => - val app1 = BetaReduce(app, fn, app.args) - if app1 ne app then report.log(i"beta reduce $app -> $app1") - app1 - case _ => - app - + override def transformApply(app: Apply)(using Context): Tree = + val app1 = BetaReduce(app) + if app1 ne app then report.log(i"beta reduce $app -> $app1") + app1 object BetaReduce: import ast.tpd._ @@ -53,36 +50,77 @@ object BetaReduce: val name: String = "betaReduce" val description: String = "reduce closure applications" - /** Beta-reduces a call to `fn` with arguments `argSyms` or returns `tree` */ - def apply(original: Tree, fn: Tree, args: List[Tree])(using Context): Tree = - fn match - case Typed(expr, _) => - BetaReduce(original, expr, args) - case Block((anonFun: DefDef) :: Nil, closure: Closure) => - BetaReduce(anonFun, args) - case Block(stats, expr) => - val tree = BetaReduce(original, expr, args) - if tree eq original then original - else cpy.Block(fn)(stats, tree) - case Inlined(call, bindings, expr) => - val tree = BetaReduce(original, expr, args) - if tree eq original then original - else cpy.Inlined(fn)(call, bindings, tree) + /** Rewrite an application + * + * ((x1, ..., xn) => b)(e1, ..., en) + * + * to + * + * val/def x1 = e1; ...; val/def xn = en; b + * + * where `def` is used for call-by-name parameters. However, we shortcut any NoPrefix + * refs among the ei's directly without creating an intermediate binding. + * + * Similarly, rewrites type applications + * + * ([X1, ..., Xm] => (x1, ..., xn) => b).apply[T1, .., Tm](e1, ..., en) + * + * to + * + * type X1 = T1; ...; type Xm = Tm;val/def x1 = e1; ...; val/def xn = en; b + * + * This beta-reduction preserves the integrity of `Inlined` tree nodes. + */ + def apply(tree: Tree)(using Context): Tree = + val bindingsBuf = new ListBuffer[DefTree] + def recur(fn: Tree, argss: List[List[Tree]]): Option[Tree] = fn match + case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol => + Some(reduceApplication(ddef, argss, bindingsBuf)) + case Block((TypeDef(_, template: Template)) :: Nil, Typed(Apply(Select(New(_), _), _), _)) if template.constr.rhs.isEmpty => + template.body match + case (ddef: DefDef) :: Nil => Some(reduceApplication(ddef, argss, bindingsBuf)) + case _ => None + case Block(stats, expr) if stats.forall(isPureBinding) => + recur(expr, argss).map(cpy.Block(fn)(stats, _)) + case Inlined(call, bindings, expr) if bindings.forall(isPureBinding) => + recur(expr, argss).map(cpy.Inlined(fn)(call, bindings, _)) + case Typed(expr, tpt) => + recur(expr, argss) + case TypeApply(Select(expr, nme.asInstanceOfPM), List(tpt)) => + recur(expr, argss) + case _ => None + tree match + case Apply(Select(fn, nme.apply), args) if defn.isFunctionType(fn.tpe) => + recur(fn, List(args)) match + case Some(reduced) => + seq(bindingsBuf.result(), reduced).withSpan(tree.span) + case None => + tree + case Apply(TypeApply(Select(fn, nme.apply), targs), args) if fn.tpe.typeSymbol eq dotc.core.Symbols.defn.PolyFunctionClass => + recur(fn, List(targs, args)) match + case Some(reduced) => + seq(bindingsBuf.result(), reduced).withSpan(tree.span) + case None => + tree case _ => - original - end apply - - /** Beta-reduces a call to `ddef` with arguments `args` */ - def apply(ddef: DefDef, args: List[Tree])(using Context) = - val bindings = new ListBuffer[ValDef]() - val expansion1 = reduceApplication(ddef, args, bindings) - val bindings1 = bindings.result() - seq(bindings1, expansion1) + tree /** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */ - def reduceApplication(ddef: DefDef, args: List[Tree], bindings: ListBuffer[ValDef])(using Context): Tree = - val vparams = ddef.termParamss.iterator.flatten.toList - assert(args.hasSameLengthAs(vparams)) + def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])(using Context): Tree = + val (targs, args) = argss.flatten.partition(_.isType) + val tparams = ddef.leadingTypeParams + val vparams = ddef.termParamss.flatten + + val targSyms = + for (targ, tparam) <- targs.zip(tparams) yield + targ.tpe.dealias match + case ref @ TypeRef(NoPrefix, _) => + ref.symbol + case _ => + val binding = TypeDef(newSymbol(ctx.owner, tparam.name, EmptyFlags, targ.tpe, coord = targ.span)).withSpan(targ.span) + bindings += binding + binding.symbol + val argSyms = for (arg, param) <- args.zip(vparams) yield arg.tpe.dealias match @@ -99,8 +137,8 @@ object BetaReduce: val expansion = TreeTypeMap( oldOwners = ddef.symbol :: Nil, newOwners = ctx.owner :: Nil, - substFrom = vparams.map(_.symbol), - substTo = argSyms + substFrom = (tparams ::: vparams).map(_.symbol), + substTo = targSyms ::: argSyms ).transform(ddef.rhs) val expansion1 = new TreeMap { diff --git a/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala b/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala index 6edb60a77245..798f34757b35 100644 --- a/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala +++ b/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala @@ -8,6 +8,8 @@ import Symbols._, Contexts._, Types._, Decorators._ import NameOps._ import Names._ +import scala.collection.mutable.ListBuffer + /** Rewrite an application * * {new { def unapply(x0: X0)(x1: X1,..., xn: Xn) = b }}.unapply(y0)(y1, ..., yn) @@ -38,7 +40,7 @@ class InlinePatterns extends MiniPhase: if app.symbol.name.isUnapplyName && !app.tpe.isInstanceOf[MethodicType] then app match case App(Select(fn, name), argss) => - val app1 = betaReduce(app, fn, name, argss.flatten) + val app1 = betaReduce(app, fn, name, argss) if app1 ne app then report.log(i"beta reduce $app -> $app1") app1 case _ => @@ -51,11 +53,16 @@ class InlinePatterns extends MiniPhase: case Apply(App(fn, argss), args) => (fn, argss :+ args) case _ => (app, Nil) - private def betaReduce(tree: Apply, fn: Tree, name: Name, args: List[Tree])(using Context): Tree = + // TODO merge with BetaReduce.scala + private def betaReduce(tree: Apply, fn: Tree, name: Name, argss: List[List[Tree]])(using Context): Tree = fn match case Block(TypeDef(_, template: Template) :: Nil, Apply(Select(New(_),_), Nil)) if template.constr.rhs.isEmpty => template.body match - case List(ddef @ DefDef(`name`, _, _, _)) => BetaReduce(ddef, args) + case List(ddef @ DefDef(`name`, _, _, _)) => + val bindings = new ListBuffer[DefTree]() + val expansion1 = BetaReduce.reduceApplication(ddef, argss, bindings) + val bindings1 = bindings.result() + seq(bindings1, expansion1) case _ => tree case _ => tree diff --git a/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala b/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala index 21fc27cec0dd..c56bac4d66af 100644 --- a/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala +++ b/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala @@ -322,7 +322,10 @@ object PickleQuotes { } val Block(List(ddef: DefDef), _) = splice: @unchecked // TODO: beta reduce inner closure? Or wait until BetaReduce phase? - BetaReduce(ddef, spliceArgs).select(nme.apply).appliedTo(args(2).asInstance(quotesType)) + BetaReduce( + splice + .select(nme.apply).appliedToArgs(spliceArgs)) + .select(nme.apply).appliedTo(args(2).asInstance(quotesType)) } CaseDef(Literal(Constant(idx)), EmptyTree, rhs) } diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index 1ec13ba832c9..15ed447fd680 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -362,16 +362,15 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler object Term extends TermModule: def betaReduce(tree: Term): Option[Term] = tree match - case app @ tpd.Apply(tpd.Select(fn, nme.apply), args) if dotc.core.Symbols.defn.isFunctionType(fn.tpe) => - val app1 = dotc.transform.BetaReduce(app, fn, args) - if app1 eq app then None - else Some(app1.withSpan(tree.span)) case tpd.Block(Nil, expr) => for e <- betaReduce(expr) yield tpd.cpy.Block(tree)(Nil, e) case tpd.Inlined(_, Nil, expr) => betaReduce(expr) case _ => - None + val tree1 = dotc.transform.BetaReduce(tree) + if tree1 eq tree then None + else Some(tree1.withSpan(tree.span)) + end Term given TermMethods: TermMethods with diff --git a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala index 33e898718b33..5f9318c0c3f0 100644 --- a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala @@ -582,6 +582,63 @@ class InlineBytecodeTests extends DottyBytecodeTest { } } + @Test def beta_reduce_polymorphic_function = { + val source = """class Test: + | def test = + | ([Z] => (arg: Z) => { val a: Z = arg; a }).apply[Int](2) + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Test.class", directory = false).input + val clsNode = loadClassNode(clsIn) + + val fun = getMethod(clsNode, "test") + val instructions = instructionsFromMethod(fun) + val expected = + List( + Op(ICONST_2), + VarOp(ISTORE, 1), + VarOp(ILOAD, 1), + Op(IRETURN) + ) + + assert(instructions == expected, + "`i was not properly beta-reduced in `test`\n" + diffInstructions(instructions, expected)) + + } + } + + @Test def beta_reduce_function_of_opaque_types = { + val source = """object foo: + | opaque type T = Int + | inline def apply(inline op: T => T): T = op(2) + | + |class Test: + | def test = foo { n => n } + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Test.class", directory = false).input + val clsNode = loadClassNode(clsIn) + + val fun = getMethod(clsNode, "test") + val instructions = instructionsFromMethod(fun) + val expected = + List( + Field(GETSTATIC, "foo$", "MODULE$", "Lfoo$;"), + VarOp(ASTORE, 1), + VarOp(ALOAD, 1), + VarOp(ASTORE, 2), + Op(ICONST_2), + Op(IRETURN), + ) + + assert(instructions == expected, + "`i was not properly beta-reduced in `test`\n" + diffInstructions(instructions, expected)) + + } + } + @Test def i9456 = { val source = """class Foo { | def test: Int = inline2(inline1(2.+)) diff --git a/tests/run-macros/i15968.check b/tests/run-macros/i15968.check new file mode 100644 index 000000000000..c7f3847d404c --- /dev/null +++ b/tests/run-macros/i15968.check @@ -0,0 +1,5 @@ +{ + type Z = java.lang.String + "foo".toString() +} +"foo".toString() diff --git a/tests/run-macros/i15968/Macro_1.scala b/tests/run-macros/i15968/Macro_1.scala new file mode 100644 index 000000000000..ea2728840d6e --- /dev/null +++ b/tests/run-macros/i15968/Macro_1.scala @@ -0,0 +1,15 @@ +import scala.quoted.* + +inline def macroPolyFun[A](inline arg: A, inline f: [Z] => Z => String): String = + ${ macroPolyFunImpl[A]('arg, 'f) } + +private def macroPolyFunImpl[A: Type](arg: Expr[A], f: Expr[[Z] => Z => String])(using Quotes): Expr[String] = + Expr(Expr.betaReduce('{ $f($arg) }).show) + + +inline def macroFun[A](inline arg: A, inline f: A => String): String = + ${ macroFunImpl[A]('arg, 'f) } + +private def macroFunImpl[A: Type](arg: Expr[A], f: Expr[A => String])(using Quotes): Expr[String] = + Expr(Expr.betaReduce('{ $f($arg) }).show) + diff --git a/tests/run-macros/i15968/Test_2.scala b/tests/run-macros/i15968/Test_2.scala new file mode 100644 index 000000000000..6c6826f96b34 --- /dev/null +++ b/tests/run-macros/i15968/Test_2.scala @@ -0,0 +1,3 @@ +@main def Test: Unit = + println(macroPolyFun("foo", [Z] => (arg: Z) => arg.toString)) + println(macroFun("foo", arg => arg.toString)) diff --git a/tests/run-macros/inline-beta-reduce-polyfunction.check b/tests/run-macros/inline-beta-reduce-polyfunction.check new file mode 100644 index 000000000000..7793e273864f --- /dev/null +++ b/tests/run-macros/inline-beta-reduce-polyfunction.check @@ -0,0 +1,7 @@ +{ + type X = Int + { + println(1) + 1 + } +} diff --git a/tests/run-macros/inline-beta-reduce-polyfunction.scala b/tests/run-macros/inline-beta-reduce-polyfunction.scala new file mode 100644 index 000000000000..60ef889e7260 --- /dev/null +++ b/tests/run-macros/inline-beta-reduce-polyfunction.scala @@ -0,0 +1,5 @@ +transparent inline def foo(inline f: [X] => X => X): Int = f[Int](1) + +@main def Test: Unit = + val code = compiletime.codeOf(foo([X] => (x: X) => { println(x); x })) + println(code) \ No newline at end of file