Skip to content

Commit 65ada3c

Browse files
Merge pull request #8457 from dotty-staging/add-lambda-extractor
Implement quoted Lambda extractor
2 parents 5db52c3 + 7446c65 commit 65ada3c

File tree

12 files changed

+171
-40
lines changed

12 files changed

+171
-40
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2051,44 +2051,51 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
20512051
}}
20522052
val argVals = argVals0.reverse
20532053
val argRefs = argRefs0.reverse
2054-
def rec(fn: Tree, topAscription: Option[TypeTree]): Tree = fn match {
2055-
case Typed(expr, tpt) =>
2056-
// we need to retain any type ascriptions we see and:
2057-
// a) if we succeed, ascribe the result type of the ascription to the inlined body
2058-
// b) if we fail, re-ascribe the same type to whatever it was we couldn't inline
2059-
// note: if you see many nested ascriptions, keep the top one as that's what the enclosing expression expects
2060-
rec(expr, topAscription.orElse(Some(tpt)))
2061-
case Inlined(call, bindings, expansion) =>
2062-
// this case must go before closureDef to avoid dropping the inline node
2063-
cpy.Inlined(fn)(call, bindings, rec(expansion, topAscription))
2064-
case closureDef(ddef) =>
2065-
val paramSyms = ddef.vparamss.head.map(param => param.symbol)
2066-
val paramToVals = paramSyms.zip(argRefs).toMap
2067-
val result = new TreeTypeMap(
2068-
oldOwners = ddef.symbol :: Nil,
2069-
newOwners = ctx.owner :: Nil,
2070-
treeMap = tree => paramToVals.get(tree.symbol).map(_.withSpan(tree.span)).getOrElse(tree)
2071-
).transform(ddef.rhs)
2072-
topAscription match {
2073-
case Some(tpt) =>
2074-
// we assume the ascribed type has an apply that has a MethodType with a single param list (there should be no polys)
2075-
val methodType = tpt.tpe.member(nme.apply).info.asInstanceOf[MethodType]
2054+
val reducedBody = lambdaExtractor(fn, argRefs.map(_.tpe)) match {
2055+
case Some(body) => body(argRefs)
2056+
case None => fn.select(nme.apply).appliedToArgs(argRefs)
2057+
}
2058+
seq(argVals, reducedBody).withSpan(fn.span)
2059+
}
2060+
2061+
def lambdaExtractor(fn: Term, paramTypes: List[Type])(using ctx: Context): Option[List[Term] => Term] = {
2062+
def rec(fn: Term, transformBody: Term => Term): Option[List[Term] => Term] = {
2063+
fn match {
2064+
case Inlined(call, bindings, expansion) =>
2065+
// this case must go before closureDef to avoid dropping the inline node
2066+
rec(expansion, cpy.Inlined(fn)(call, bindings, _))
2067+
case Typed(expr, tpt) =>
2068+
val tpe = tpt.tpe.dropDependentRefinement
2069+
// we checked that this is a plain Function closure, so there will be an apply method with a MethodType
2070+
// and the expected signature based on param types
2071+
val expectedSig = Signature.NotAMethod.prependTermParams(paramTypes, false)
2072+
val method = tpt.tpe.member(nme.apply).atSignature(expectedSig)
2073+
if method.symbol.is(Deferred) then
2074+
val methodType = method.info.asInstanceOf[MethodType]
20762075
// result might contain paramrefs, so we substitute them with arg termrefs
2077-
val resultTypeWithSubst = methodType.resultType.substParams(methodType, argRefs.map(_.tpe))
2078-
Typed(result, TypeTree(resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span)
2079-
case None =>
2080-
result
2081-
}
2082-
case tpd.Block(stats, expr) =>
2083-
seq(stats, rec(expr, topAscription)).withSpan(fn.span)
2084-
case _ =>
2085-
val maybeAscribed = topAscription match {
2086-
case Some(tpt) => Typed(fn, tpt).withSpan(fn.span)
2087-
case None => fn
2088-
}
2089-
maybeAscribed.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span)
2076+
val resultTypeWithSubst = methodType.resultType.substParams(methodType, paramTypes)
2077+
rec(expr, Typed(_, TypeTree(resultTypeWithSubst).withSpan(tpt.span)))
2078+
else
2079+
None
2080+
case cl @ closureDef(ddef) =>
2081+
def replace(body: Term, argRefs: List[Term]): Term = {
2082+
val paramSyms = ddef.vparamss.head.map(param => param.symbol)
2083+
val paramToVals = paramSyms.zip(argRefs).toMap
2084+
new TreeTypeMap(
2085+
oldOwners = ddef.symbol :: Nil,
2086+
newOwners = ctx.owner :: Nil,
2087+
treeMap = tree => paramToVals.get(tree.symbol).map(_.withSpan(tree.span)).getOrElse(tree)
2088+
).transform(body)
2089+
}
2090+
Some(argRefs => replace(transformBody(ddef.rhs), argRefs))
2091+
case Block(stats, expr) =>
2092+
// this case must go after closureDef to avoid matching the closure
2093+
rec(expr, cpy.Block(fn)(stats, _))
2094+
case _ =>
2095+
None
2096+
}
20902097
}
2091-
seq(argVals, rec(fn, None))
2098+
rec(fn, identity)
20922099
}
20932100

20942101
/////////////
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package scala.quoted
2+
package matching
3+
4+
/** Lambda expression extractor */
5+
object Lambda {
6+
7+
/** `case Lambda(fn)` matches a lambda by lifting the function from `S => T` to `Expr[S] => Expr[T]`.
8+
* As the body may (will) contain references to the paramter, `body` is a function that recieves those arguments as `Expr`.
9+
* Once this function is applied the result will be the body of the lambda with all references to the parameters replaced.
10+
* If `body` is of type `(T1, T2, ...) => R` then body will be of type `(Expr[T1], Expr[T2], ...) => Expr[R]`.
11+
*
12+
* ```
13+
* '{ (x: Int) => println(x) } match
14+
* case Lambda(body) =>
15+
* // where `body` is: (x: Expr[Int]) => '{ println($x) }
16+
* body('{3}) // returns '{ println(3) }
17+
* ```
18+
*/
19+
def unapply[F, Args <: Tuple, Res, G](expr: Expr[F])(using qctx: QuoteContext, tf: TupledFunction[F, Args => Res], tg: TupledFunction[G, Tuple.Map[Args, Expr] => Expr[Res]], functionType: Type[F]): Option[/*QuoteContext ?=>*/ G] = {
20+
import qctx.tasty.{_, given _ }
21+
val argTypes = functionType.unseal.tpe match
22+
case AppliedType(_, functionArguments) => functionArguments.init.asInstanceOf[List[Type]]
23+
qctx.tasty.internal.lambdaExtractor(expr.unseal, argTypes).map { fn =>
24+
def f(args: Tuple.Map[Args, Expr]): Expr[Res] =
25+
fn(args.toArray.toList.map(_.asInstanceOf[Expr[Any]].unseal)).seal.asInstanceOf[Expr[Res]]
26+
tg.untupled(f)
27+
}
28+
29+
}
30+
31+
}

library/src/scala/tasty/reflect/CompilerInterface.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,4 +1543,6 @@ trait CompilerInterface {
15431543
*/
15441544
def betaReduce(f: Term, args: List[Term])(using ctx: Context): Term
15451545

1546+
def lambdaExtractor(term: Term, paramTypes: List[Type])(using ctx: Context): Option[List[Term] => Term]
1547+
15461548
}

tests/run-macros/beta-reduce-inline-result.check

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,6 @@ run-time: 4
33
compile-time: 1
44
run-time: 1
55
run-time: 5
6+
run-time: 7
7+
run-time: -1
8+
run-time: 9

tests/run-macros/beta-reduce-inline-result/Test_2.scala

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import scala.compiletime._
22

33
object Test {
4-
4+
55
inline def dummy1: Int => Int =
66
(i: Int) => i + 1
77

@@ -14,6 +14,36 @@ object Test {
1414
inline def dummy4: Int => Int =
1515
???
1616

17+
object I extends (Int => Int) {
18+
def apply(i: Int): i.type = i
19+
}
20+
21+
abstract class II extends (Int => Int) {
22+
val apply = 123
23+
}
24+
25+
inline def dummy5: II =
26+
(i: Int) => i + 1
27+
28+
abstract class III extends (Int => Int) {
29+
def impl(i: Int): Int
30+
def apply(i: Int): Int = -1
31+
}
32+
33+
inline def dummy6: III =
34+
(i: Int) => i + 1
35+
36+
abstract class IV extends (Int => Int) {
37+
def apply(s: String): String
38+
}
39+
40+
abstract class V extends IV {
41+
def apply(s: String): String = "gotcha"
42+
}
43+
44+
inline def dummy7: IV =
45+
{ (i: Int) => i + 1 } : V
46+
1747
def main(argv : Array[String]) : Unit = {
1848
println(code"compile-time: ${Macros.betaReduce(dummy1)(3)}")
1949
println(s"run-time: ${Macros.betaReduce(dummy1)(3)}")
@@ -27,7 +57,21 @@ object Test {
2757
def throwsNotImplemented2 = Macros.betaReduce(dummy4)(6)
2858

2959
// make sure paramref types work when inlining is not possible
30-
println(s"run-time: ${Macros.betaReduce(Predef.identity)(5)}")
60+
println(s"run-time: ${Macros.betaReduce(I)(5)}")
61+
62+
// -- cases below are non-function types, which are currently not inlined for simplicity but may be in the future
63+
// (also, this tests that we return something valid when we see a closure that we can't inline)
64+
65+
// A non-function type with an apply value that can be confused with the apply method.
66+
println(s"run-time: ${Macros.betaReduce(dummy5)(6)}")
67+
68+
// should print -1 (without inlining), because the apparent apply method actually
69+
// has nothing to do with the function literal
70+
println(s"run-time: ${Macros.betaReduce(dummy6)(7)}")
71+
72+
// the literal does contain the implementation of the apply method, but there are two abstract apply methods
73+
// in the outermost abstract type
74+
println(s"run-time: ${Macros.betaReduce(dummy7)(8)}")
3175
}
3276
}
3377

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
scala.Predef.identity[scala.Int](1)
2+
1
3+
{
4+
scala.Predef.println(1)
5+
1
6+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
inline def test(inline f: Int => Int): String = ${ impl('f) }
5+
6+
def impl(using QuoteContext)(f: Expr[Int => Int]): Expr[String] = {
7+
Expr(f match {
8+
case Lambda(body) => body('{1}).show
9+
case _ => f.show
10+
})
11+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
@main def Test = {
3+
println(test(identity))
4+
println(test(x => x))
5+
println(test(x => { println(x); x }))
6+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
1.+(2)
2+
{
3+
scala.Predef.println(1)
4+
2
5+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
inline def test(inline f: (Int, Int) => Int): String = ${ impl('f) }
5+
6+
def impl(using QuoteContext)(f: Expr[(Int, Int) => Int]): Expr[String] = {
7+
Expr(f match {
8+
case Lambda(body) => body('{1}, '{2}).show
9+
case _ => f.show
10+
})
11+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
@main def Test = {
3+
println(test((x, y) => x + y))
4+
println(test((x, y) => { println(x); y }))
5+
}

tests/run-staging/i3876-c.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66

77
(f: scala.Function1[scala.Int, scala.Int] {
88
def apply(x: scala.Int): scala.Int
9-
}).apply(3)
10-
}
9+
})
10+
}.apply(3)

0 commit comments

Comments
 (0)