Skip to content

Implement quoted Lambda extractor #8457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 8, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2051,44 +2051,51 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
}}
val argVals = argVals0.reverse
val argRefs = argRefs0.reverse
def rec(fn: Tree, topAscription: Option[TypeTree]): Tree = fn match {
case Typed(expr, tpt) =>
// we need to retain any type ascriptions we see and:
// a) if we succeed, ascribe the result type of the ascription to the inlined body
// b) if we fail, re-ascribe the same type to whatever it was we couldn't inline
// note: if you see many nested ascriptions, keep the top one as that's what the enclosing expression expects
rec(expr, topAscription.orElse(Some(tpt)))
case Inlined(call, bindings, expansion) =>
// this case must go before closureDef to avoid dropping the inline node
cpy.Inlined(fn)(call, bindings, rec(expansion, topAscription))
case closureDef(ddef) =>
val paramSyms = ddef.vparamss.head.map(param => param.symbol)
val paramToVals = paramSyms.zip(argRefs).toMap
val result = new TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
treeMap = tree => paramToVals.get(tree.symbol).map(_.withSpan(tree.span)).getOrElse(tree)
).transform(ddef.rhs)
topAscription match {
case Some(tpt) =>
// we assume the ascribed type has an apply that has a MethodType with a single param list (there should be no polys)
val methodType = tpt.tpe.member(nme.apply).info.asInstanceOf[MethodType]
val reducedBody = lambdaExtractor(fn, argRefs.map(_.tpe)) match {
case Some(body) => body(argRefs)
case None => fn.select(nme.apply).appliedToArgs(argRefs)
}
seq(argVals, reducedBody).withSpan(fn.span)
}

def lambdaExtractor(fn: Term, paramTypes: List[Type])(using ctx: Context): Option[List[Term] => Term] = {
def rec(fn: Term, transformBody: Term => Term): Option[List[Term] => Term] = {
fn match {
case Inlined(call, bindings, expansion) =>
// this case must go before closureDef to avoid dropping the inline node
rec(expansion, cpy.Inlined(fn)(call, bindings, _))
case Typed(expr, tpt) =>
val tpe = tpt.tpe.dropDependentRefinement
// we checked that this is a plain Function closure, so there will be an apply method with a MethodType
// and the expected signature based on param types
val expectedSig = Signature.NotAMethod.prependTermParams(paramTypes, false)
val method = tpt.tpe.member(nme.apply).atSignature(expectedSig)
if method.symbol.is(Deferred) then
val methodType = method.info.asInstanceOf[MethodType]
// result might contain paramrefs, so we substitute them with arg termrefs
val resultTypeWithSubst = methodType.resultType.substParams(methodType, argRefs.map(_.tpe))
Typed(result, TypeTree(resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span)
case None =>
result
}
case tpd.Block(stats, expr) =>
seq(stats, rec(expr, topAscription)).withSpan(fn.span)
case _ =>
val maybeAscribed = topAscription match {
case Some(tpt) => Typed(fn, tpt).withSpan(fn.span)
case None => fn
}
maybeAscribed.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span)
val resultTypeWithSubst = methodType.resultType.substParams(methodType, paramTypes)
rec(expr, Typed(_, TypeTree(resultTypeWithSubst).withSpan(tpt.span)))
else
None
case cl @ closureDef(ddef) =>
def replace(body: Term, argRefs: List[Term]): Term = {
val paramSyms = ddef.vparamss.head.map(param => param.symbol)
val paramToVals = paramSyms.zip(argRefs).toMap
new TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
treeMap = tree => paramToVals.get(tree.symbol).map(_.withSpan(tree.span)).getOrElse(tree)
).transform(body)
}
Some(argRefs => replace(transformBody(ddef.rhs), argRefs))
case Block(stats, expr) =>
// this case must go after closureDef to avoid matching the closure
rec(expr, cpy.Block(fn)(stats, _))
case _ =>
None
}
}
seq(argVals, rec(fn, None))
rec(fn, identity)
}

/////////////
Expand Down
31 changes: 31 additions & 0 deletions library/src/scala/quoted/matching/Lambda.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package scala.quoted
package matching

/** Lambda expression extractor */
object Lambda {

/** `case Lambda(body)` matche a lambda and extract the body.
* As the body may (will) contain references to the paramter, `body` is a function that recieves those arguments as `Expr`.
* Once this function is applied the result will be the body of the lambda with all references to the parameters replaced.
* If `body` is of type `(T1, T2, ...) => R` then body will be of type `(Expr[T1], Expr[T2], ...) => Expr[R]`.
*
* ```
* '{ (x: Int) => println(x) } match
* case Lambda(body) =>
* // where `body` is: (x: Expr[Int]) => '{ println($x) }
* body('{3}) // returns '{ println(3) }
* ```
*/
def unapply[F, Args <: Tuple, Res, G](expr: Expr[F])(using qctx: QuoteContext, tf: TupledFunction[F, Args => Res], tg: TupledFunction[G, Tuple.Map[Args, Expr] => Expr[Res]], functionType: Type[F]): Option[/*QuoteContext ?=>*/ G] = {
import qctx.tasty.{_, given _ }
val argTypes = functionType.unseal.tpe match
case AppliedType(_, functionArguments) => functionArguments.init.asInstanceOf[List[Type]]
qctx.tasty.internal.lambdaExtractor(expr.unseal, argTypes).map { fn =>
def f(args: Tuple.Map[Args, Expr]): Expr[Res] =
fn(args.toArray.map(_.asInstanceOf[Expr[Any]].unseal).toList).seal.asInstanceOf[Expr[Res]]
tg.untupled(f)
}

}

}
2 changes: 2 additions & 0 deletions library/src/scala/tasty/reflect/CompilerInterface.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1543,4 +1543,6 @@ trait CompilerInterface {
*/
def betaReduce(f: Term, args: List[Term])(using ctx: Context): Term

def lambdaExtractor(term: Term, paramTypes: List[Type])(using ctx: Context): Option[List[Term] => Term]

}
3 changes: 3 additions & 0 deletions tests/run-macros/beta-reduce-inline-result.check
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ run-time: 4
compile-time: 1
run-time: 1
run-time: 5
run-time: 7
run-time: -1
run-time: 9
48 changes: 46 additions & 2 deletions tests/run-macros/beta-reduce-inline-result/Test_2.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import scala.compiletime._

object Test {

inline def dummy1: Int => Int =
(i: Int) => i + 1

Expand All @@ -14,6 +14,36 @@ object Test {
inline def dummy4: Int => Int =
???

object I extends (Int => Int) {
def apply(i: Int): i.type = i
}

abstract class II extends (Int => Int) {
val apply = 123
}

inline def dummy5: II =
(i: Int) => i + 1

abstract class III extends (Int => Int) {
def impl(i: Int): Int
def apply(i: Int): Int = -1
}

inline def dummy6: III =
(i: Int) => i + 1

abstract class IV extends (Int => Int) {
def apply(s: String): String
}

abstract class V extends IV {
def apply(s: String): String = "gotcha"
}

inline def dummy7: IV =
{ (i: Int) => i + 1 } : V

def main(argv : Array[String]) : Unit = {
println(code"compile-time: ${Macros.betaReduce(dummy1)(3)}")
println(s"run-time: ${Macros.betaReduce(dummy1)(3)}")
Expand All @@ -27,7 +57,21 @@ object Test {
def throwsNotImplemented2 = Macros.betaReduce(dummy4)(6)

// make sure paramref types work when inlining is not possible
println(s"run-time: ${Macros.betaReduce(Predef.identity)(5)}")
println(s"run-time: ${Macros.betaReduce(I)(5)}")

// -- cases below are non-function types, which are currently not inlined for simplicity but may be in the future
// (also, this tests that we return something valid when we see a closure that we can't inline)

// A non-function type with an apply value that can be confused with the apply method.
println(s"run-time: ${Macros.betaReduce(dummy5)(6)}")

// should print -1 (without inlining), because the apparent apply method actually
// has nothing to do with the function literal
println(s"run-time: ${Macros.betaReduce(dummy6)(7)}")

// the literal does contain the implementation of the apply method, but there are two abstract apply methods
// in the outermost abstract type
println(s"run-time: ${Macros.betaReduce(dummy7)(8)}")
}
}

6 changes: 6 additions & 0 deletions tests/run-macros/lambda-extractor-1.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
scala.Predef.identity[scala.Int](1)
1
{
scala.Predef.println(1)
1
}
11 changes: 11 additions & 0 deletions tests/run-macros/lambda-extractor-1/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import scala.quoted._
import scala.quoted.matching._

inline def test(inline f: Int => Int): String = ${ impl('f) }

def impl(using QuoteContext)(f: Expr[Int => Int]): Expr[String] = {
Expr(f match {
case Lambda(body) => body('{1}).show
case _ => f.show
})
}
6 changes: 6 additions & 0 deletions tests/run-macros/lambda-extractor-1/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

@main def Test = {
println(test(identity))
println(test(x => x))
println(test(x => { println(x); x }))
}
5 changes: 5 additions & 0 deletions tests/run-macros/lambda-extractor-2.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
1.+(2)
{
scala.Predef.println(1)
2
}
11 changes: 11 additions & 0 deletions tests/run-macros/lambda-extractor-2/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import scala.quoted._
import scala.quoted.matching._

inline def test(inline f: (Int, Int) => Int): String = ${ impl('f) }

def impl(using QuoteContext)(f: Expr[(Int, Int) => Int]): Expr[String] = {
Expr(f match {
case Lambda(body) => body('{1}, '{2}).show
case _ => f.show
})
}
5 changes: 5 additions & 0 deletions tests/run-macros/lambda-extractor-2/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

@main def Test = {
println(test((x, y) => x + y))
println(test((x, y) => { println(x); y }))
}
4 changes: 2 additions & 2 deletions tests/run-staging/i3876-c.check
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@

(f: scala.Function1[scala.Int, scala.Int] {
def apply(x: scala.Int): scala.Int
}).apply(3)
}
})
}.apply(3)