diff --git a/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala b/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala index c5937196c40a..bc90e32e61b3 100644 --- a/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala +++ b/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala @@ -2051,44 +2051,51 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend }} val argVals = argVals0.reverse val argRefs = argRefs0.reverse - def rec(fn: Tree, topAscription: Option[TypeTree]): Tree = fn match { - case Typed(expr, tpt) => - // we need to retain any type ascriptions we see and: - // a) if we succeed, ascribe the result type of the ascription to the inlined body - // b) if we fail, re-ascribe the same type to whatever it was we couldn't inline - // note: if you see many nested ascriptions, keep the top one as that's what the enclosing expression expects - rec(expr, topAscription.orElse(Some(tpt))) - case Inlined(call, bindings, expansion) => - // this case must go before closureDef to avoid dropping the inline node - cpy.Inlined(fn)(call, bindings, rec(expansion, topAscription)) - case closureDef(ddef) => - val paramSyms = ddef.vparamss.head.map(param => param.symbol) - val paramToVals = paramSyms.zip(argRefs).toMap - val result = new TreeTypeMap( - oldOwners = ddef.symbol :: Nil, - newOwners = ctx.owner :: Nil, - treeMap = tree => paramToVals.get(tree.symbol).map(_.withSpan(tree.span)).getOrElse(tree) - ).transform(ddef.rhs) - topAscription match { - case Some(tpt) => - // we assume the ascribed type has an apply that has a MethodType with a single param list (there should be no polys) - val methodType = tpt.tpe.member(nme.apply).info.asInstanceOf[MethodType] + val reducedBody = lambdaExtractor(fn, argRefs.map(_.tpe)) match { + case Some(body) => body(argRefs) + case None => fn.select(nme.apply).appliedToArgs(argRefs) + } + seq(argVals, reducedBody).withSpan(fn.span) + } + + def lambdaExtractor(fn: Term, paramTypes: List[Type])(using ctx: Context): Option[List[Term] => Term] = { + def rec(fn: Term, transformBody: Term => Term): Option[List[Term] => Term] = { + fn match { + case Inlined(call, bindings, expansion) => + // this case must go before closureDef to avoid dropping the inline node + rec(expansion, cpy.Inlined(fn)(call, bindings, _)) + case Typed(expr, tpt) => + val tpe = tpt.tpe.dropDependentRefinement + // we checked that this is a plain Function closure, so there will be an apply method with a MethodType + // and the expected signature based on param types + val expectedSig = Signature.NotAMethod.prependTermParams(paramTypes, false) + val method = tpt.tpe.member(nme.apply).atSignature(expectedSig) + if method.symbol.is(Deferred) then + val methodType = method.info.asInstanceOf[MethodType] // result might contain paramrefs, so we substitute them with arg termrefs - val resultTypeWithSubst = methodType.resultType.substParams(methodType, argRefs.map(_.tpe)) - Typed(result, TypeTree(resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span) - case None => - result - } - case tpd.Block(stats, expr) => - seq(stats, rec(expr, topAscription)).withSpan(fn.span) - case _ => - val maybeAscribed = topAscription match { - case Some(tpt) => Typed(fn, tpt).withSpan(fn.span) - case None => fn - } - maybeAscribed.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span) + val resultTypeWithSubst = methodType.resultType.substParams(methodType, paramTypes) + rec(expr, Typed(_, TypeTree(resultTypeWithSubst).withSpan(tpt.span))) + else + None + case cl @ closureDef(ddef) => + def replace(body: Term, argRefs: List[Term]): Term = { + val paramSyms = ddef.vparamss.head.map(param => param.symbol) + val paramToVals = paramSyms.zip(argRefs).toMap + new TreeTypeMap( + oldOwners = ddef.symbol :: Nil, + newOwners = ctx.owner :: Nil, + treeMap = tree => paramToVals.get(tree.symbol).map(_.withSpan(tree.span)).getOrElse(tree) + ).transform(body) + } + Some(argRefs => replace(transformBody(ddef.rhs), argRefs)) + case Block(stats, expr) => + // this case must go after closureDef to avoid matching the closure + rec(expr, cpy.Block(fn)(stats, _)) + case _ => + None + } } - seq(argVals, rec(fn, None)) + rec(fn, identity) } ///////////// diff --git a/library/src/scala/quoted/matching/Lambda.scala b/library/src/scala/quoted/matching/Lambda.scala new file mode 100644 index 000000000000..dbc5bb159fe2 --- /dev/null +++ b/library/src/scala/quoted/matching/Lambda.scala @@ -0,0 +1,31 @@ +package scala.quoted +package matching + +/** Lambda expression extractor */ +object Lambda { + + /** `case Lambda(fn)` matches a lambda by lifting the function from `S => T` to `Expr[S] => Expr[T]`. + * As the body may (will) contain references to the paramter, `body` is a function that recieves those arguments as `Expr`. + * Once this function is applied the result will be the body of the lambda with all references to the parameters replaced. + * If `body` is of type `(T1, T2, ...) => R` then body will be of type `(Expr[T1], Expr[T2], ...) => Expr[R]`. + * + * ``` + * '{ (x: Int) => println(x) } match + * case Lambda(body) => + * // where `body` is: (x: Expr[Int]) => '{ println($x) } + * body('{3}) // returns '{ println(3) } + * ``` + */ + def unapply[F, Args <: Tuple, Res, G](expr: Expr[F])(using qctx: QuoteContext, tf: TupledFunction[F, Args => Res], tg: TupledFunction[G, Tuple.Map[Args, Expr] => Expr[Res]], functionType: Type[F]): Option[/*QuoteContext ?=>*/ G] = { + import qctx.tasty.{_, given _ } + val argTypes = functionType.unseal.tpe match + case AppliedType(_, functionArguments) => functionArguments.init.asInstanceOf[List[Type]] + qctx.tasty.internal.lambdaExtractor(expr.unseal, argTypes).map { fn => + def f(args: Tuple.Map[Args, Expr]): Expr[Res] = + fn(args.toArray.toList.map(_.asInstanceOf[Expr[Any]].unseal)).seal.asInstanceOf[Expr[Res]] + tg.untupled(f) + } + + } + +} diff --git a/library/src/scala/tasty/reflect/CompilerInterface.scala b/library/src/scala/tasty/reflect/CompilerInterface.scala index 6d372cfeccc8..64fc238c49b3 100644 --- a/library/src/scala/tasty/reflect/CompilerInterface.scala +++ b/library/src/scala/tasty/reflect/CompilerInterface.scala @@ -1543,4 +1543,6 @@ trait CompilerInterface { */ def betaReduce(f: Term, args: List[Term])(using ctx: Context): Term + def lambdaExtractor(term: Term, paramTypes: List[Type])(using ctx: Context): Option[List[Term] => Term] + } diff --git a/tests/run-macros/beta-reduce-inline-result.check b/tests/run-macros/beta-reduce-inline-result.check index 082514df02f7..3735f7520b82 100644 --- a/tests/run-macros/beta-reduce-inline-result.check +++ b/tests/run-macros/beta-reduce-inline-result.check @@ -3,3 +3,6 @@ run-time: 4 compile-time: 1 run-time: 1 run-time: 5 +run-time: 7 +run-time: -1 +run-time: 9 diff --git a/tests/run-macros/beta-reduce-inline-result/Test_2.scala b/tests/run-macros/beta-reduce-inline-result/Test_2.scala index 978b3e5d2f41..70c8b3e9e5ad 100644 --- a/tests/run-macros/beta-reduce-inline-result/Test_2.scala +++ b/tests/run-macros/beta-reduce-inline-result/Test_2.scala @@ -1,7 +1,7 @@ import scala.compiletime._ object Test { - + inline def dummy1: Int => Int = (i: Int) => i + 1 @@ -14,6 +14,36 @@ object Test { inline def dummy4: Int => Int = ??? + object I extends (Int => Int) { + def apply(i: Int): i.type = i + } + + abstract class II extends (Int => Int) { + val apply = 123 + } + + inline def dummy5: II = + (i: Int) => i + 1 + + abstract class III extends (Int => Int) { + def impl(i: Int): Int + def apply(i: Int): Int = -1 + } + + inline def dummy6: III = + (i: Int) => i + 1 + + abstract class IV extends (Int => Int) { + def apply(s: String): String + } + + abstract class V extends IV { + def apply(s: String): String = "gotcha" + } + + inline def dummy7: IV = + { (i: Int) => i + 1 } : V + def main(argv : Array[String]) : Unit = { println(code"compile-time: ${Macros.betaReduce(dummy1)(3)}") println(s"run-time: ${Macros.betaReduce(dummy1)(3)}") @@ -27,7 +57,21 @@ object Test { def throwsNotImplemented2 = Macros.betaReduce(dummy4)(6) // make sure paramref types work when inlining is not possible - println(s"run-time: ${Macros.betaReduce(Predef.identity)(5)}") + println(s"run-time: ${Macros.betaReduce(I)(5)}") + + // -- cases below are non-function types, which are currently not inlined for simplicity but may be in the future + // (also, this tests that we return something valid when we see a closure that we can't inline) + + // A non-function type with an apply value that can be confused with the apply method. + println(s"run-time: ${Macros.betaReduce(dummy5)(6)}") + + // should print -1 (without inlining), because the apparent apply method actually + // has nothing to do with the function literal + println(s"run-time: ${Macros.betaReduce(dummy6)(7)}") + + // the literal does contain the implementation of the apply method, but there are two abstract apply methods + // in the outermost abstract type + println(s"run-time: ${Macros.betaReduce(dummy7)(8)}") } } diff --git a/tests/run-macros/lambda-extractor-1.check b/tests/run-macros/lambda-extractor-1.check new file mode 100644 index 000000000000..d04e335f82f7 --- /dev/null +++ b/tests/run-macros/lambda-extractor-1.check @@ -0,0 +1,6 @@ +scala.Predef.identity[scala.Int](1) +1 +{ + scala.Predef.println(1) + 1 +} diff --git a/tests/run-macros/lambda-extractor-1/Macro_1.scala b/tests/run-macros/lambda-extractor-1/Macro_1.scala new file mode 100644 index 000000000000..c191a84c3792 --- /dev/null +++ b/tests/run-macros/lambda-extractor-1/Macro_1.scala @@ -0,0 +1,11 @@ +import scala.quoted._ +import scala.quoted.matching._ + +inline def test(inline f: Int => Int): String = ${ impl('f) } + +def impl(using QuoteContext)(f: Expr[Int => Int]): Expr[String] = { + Expr(f match { + case Lambda(body) => body('{1}).show + case _ => f.show + }) +} diff --git a/tests/run-macros/lambda-extractor-1/Test_2.scala b/tests/run-macros/lambda-extractor-1/Test_2.scala new file mode 100644 index 000000000000..e4a0ed7fd8b7 --- /dev/null +++ b/tests/run-macros/lambda-extractor-1/Test_2.scala @@ -0,0 +1,6 @@ + +@main def Test = { + println(test(identity)) + println(test(x => x)) + println(test(x => { println(x); x })) +} diff --git a/tests/run-macros/lambda-extractor-2.check b/tests/run-macros/lambda-extractor-2.check new file mode 100644 index 000000000000..e42b39f01040 --- /dev/null +++ b/tests/run-macros/lambda-extractor-2.check @@ -0,0 +1,5 @@ +1.+(2) +{ + scala.Predef.println(1) + 2 +} diff --git a/tests/run-macros/lambda-extractor-2/Macro_1.scala b/tests/run-macros/lambda-extractor-2/Macro_1.scala new file mode 100644 index 000000000000..0798e9b33990 --- /dev/null +++ b/tests/run-macros/lambda-extractor-2/Macro_1.scala @@ -0,0 +1,11 @@ +import scala.quoted._ +import scala.quoted.matching._ + +inline def test(inline f: (Int, Int) => Int): String = ${ impl('f) } + +def impl(using QuoteContext)(f: Expr[(Int, Int) => Int]): Expr[String] = { + Expr(f match { + case Lambda(body) => body('{1}, '{2}).show + case _ => f.show + }) +} diff --git a/tests/run-macros/lambda-extractor-2/Test_2.scala b/tests/run-macros/lambda-extractor-2/Test_2.scala new file mode 100644 index 000000000000..25719f80c843 --- /dev/null +++ b/tests/run-macros/lambda-extractor-2/Test_2.scala @@ -0,0 +1,5 @@ + +@main def Test = { + println(test((x, y) => x + y)) + println(test((x, y) => { println(x); y })) +} diff --git a/tests/run-staging/i3876-c.check b/tests/run-staging/i3876-c.check index dca23bcfdf11..38c85ed40818 100644 --- a/tests/run-staging/i3876-c.check +++ b/tests/run-staging/i3876-c.check @@ -6,5 +6,5 @@ (f: scala.Function1[scala.Int, scala.Int] { def apply(x: scala.Int): scala.Int - }).apply(3) -} + }) +}.apply(3)