Skip to content

Commit 0094097

Browse files
Merge pull request #3880 from dotty-staging/fix-#3876
Fix #3876: Implement Expr.AsFunction
2 parents 6ad8d72 + 2dc8bfa commit 0094097

File tree

5 files changed

+134
-17
lines changed

5 files changed

+134
-17
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,9 +534,18 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
534534
}
535535
}
536536

537+
/** An extractor for def of a closure contained the block of the closure. */
538+
object closureDef {
539+
def unapply(tree: Tree): Option[DefDef] = tree match {
540+
case Block((meth @ DefDef(nme.ANON_FUN, _, _, _, _)) :: Nil, closure: Closure) =>
541+
Some(meth)
542+
case _ => None
543+
}
544+
}
545+
537546
/** If tree is a closure, its body, otherwise tree itself */
538547
def closureBody(tree: Tree)(implicit ctx: Context): Tree = tree match {
539-
case Block((meth @ DefDef(nme.ANON_FUN, _, _, _, _)) :: Nil, Closure(_, _, _)) => meth.rhs
548+
case closureDef(meth) => meth.rhs
540549
case _ => tree
541550
}
542551

compiler/src/dotty/tools/dotc/core/quoted/PickledQuotes.scala

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
package dotty.tools.dotc.core.quoted
22

33
import dotty.tools.dotc.ast.Trees._
4-
import dotty.tools.dotc.ast.{tpd, untpd}
4+
import dotty.tools.dotc.ast.tpd
55
import dotty.tools.dotc.config.Printers._
66
import dotty.tools.dotc.core.Constants.Constant
77
import dotty.tools.dotc.core.Contexts._
88
import dotty.tools.dotc.core.Decorators._
99
import dotty.tools.dotc.core.Flags._
10+
import dotty.tools.dotc.core.NameKinds
1011
import dotty.tools.dotc.core.StdNames._
1112
import dotty.tools.dotc.core.Symbols._
1213
import dotty.tools.dotc.core.tasty.{TastyPickler, TastyPrinter, TastyString}
@@ -33,21 +34,16 @@ object PickledQuotes {
3334

3435
/** Transform the expression into its fully spliced Tree */
3536
def quotedToTree(expr: quoted.Quoted)(implicit ctx: Context): Tree = expr match {
36-
case expr: quoted.TastyQuoted => unpickleQuote(expr)
37-
case expr: quoted.Liftable.ConstantExpr[_] => Literal(Constant(expr.value))
37+
case expr: quoted.TastyQuoted =>
38+
unpickleQuote(expr)
39+
case expr: quoted.Liftable.ConstantExpr[_] =>
40+
Literal(Constant(expr.value))
41+
case expr: quoted.Expr.FunctionAppliedTo[_, _] =>
42+
functionAppliedTo(quotedToTree(expr.f), quotedToTree(expr.x))
3843
case expr: quoted.Type.TaggedPrimitive[_] =>
39-
val tpe = expr.ct match {
40-
case ClassTag.Unit => defn.UnitType
41-
case ClassTag.Byte => defn.ByteType
42-
case ClassTag.Char => defn.CharType
43-
case ClassTag.Short => defn.ShortType
44-
case ClassTag.Int => defn.IntType
45-
case ClassTag.Long => defn.LongType
46-
case ClassTag.Float => defn.FloatType
47-
case ClassTag.Double => defn.FloatType
48-
}
49-
TypeTree(tpe)
50-
case expr: RawQuoted => expr.tree
44+
classTagToTypeTree(expr.ct)
45+
case expr: RawQuoted =>
46+
expr.tree
5147
}
5248

5349
/** Unpickle the tree contained in the TastyQuoted */
@@ -111,4 +107,42 @@ object PickledQuotes {
111107
}
112108
tree
113109
}
110+
111+
private def classTagToTypeTree(ct: ClassTag[_])(implicit ctx: Context): TypeTree = {
112+
val tpe = ct match {
113+
case ClassTag.Unit => defn.UnitType
114+
case ClassTag.Byte => defn.ByteType
115+
case ClassTag.Char => defn.CharType
116+
case ClassTag.Short => defn.ShortType
117+
case ClassTag.Int => defn.IntType
118+
case ClassTag.Long => defn.LongType
119+
case ClassTag.Float => defn.FloatType
120+
case ClassTag.Double => defn.FloatType
121+
}
122+
TypeTree(tpe)
123+
}
124+
125+
private def functionAppliedTo(f: Tree, x: Tree)(implicit ctx: Context): Tree = {
126+
val x1 = SyntheticValDef(NameKinds.UniqueName.fresh("x".toTermName), x)
127+
def x1Ref() = ref(x1.symbol)
128+
def rec(f: Tree): Tree = f match {
129+
case closureDef(ddef) =>
130+
new TreeMap() {
131+
private val paramSym = ddef.vparamss.head.head.symbol
132+
override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = tree match {
133+
case tree: Ident if tree.symbol == paramSym => x1Ref().withPos(tree.pos)
134+
case _ => super.transform(tree)
135+
}
136+
}.transform(ddef.rhs)
137+
case Block(stats, expr) =>
138+
val applied = rec(expr)
139+
if (stats.isEmpty) applied
140+
else Block(stats, applied)
141+
case Inlined(call, bindings, expansion) =>
142+
Inlined(call, bindings, rec(expansion))
143+
case _ =>
144+
f.select(nme.apply).appliedTo(x1Ref())
145+
}
146+
Block(x1 :: Nil, rec(f))
147+
}
114148
}

library/src/scala/quoted/Expr.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ object Expr {
1313
ev.toExpr(x)
1414

1515
implicit class AsFunction[T, U](private val f: Expr[T => U]) extends AnyVal {
16-
def apply(x: Expr[T]): Expr[U] = ???
16+
def apply(x: Expr[T]): Expr[U] = new FunctionAppliedTo[T, U](f, x)
1717
}
18+
19+
final class FunctionAppliedTo[T, U] private[Expr](val f: Expr[T => U], val x: Expr[T]) extends Expr[U]
1820
}

tests/run-with-compiler/i3876.check

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
6
2+
{
3+
val x$1: Int = 3
4+
{
5+
x$1.+(x$1)
6+
}
7+
}
8+
6
9+
{
10+
val x$1: Int = 3
11+
{
12+
def f(x: Int): Int = x.+(x)
13+
f(x$1)
14+
}
15+
}
16+
6
17+
{
18+
val x$1: Int = 3
19+
{
20+
val f:
21+
Function1[Int, Int]
22+
{
23+
def apply(x: Int): Int
24+
}
25+
=
26+
{
27+
(x: Int) => x.+(x)
28+
}
29+
(f: (x: Int) => Int).apply(x$1)
30+
}
31+
}
32+
6
33+
{
34+
val x$1: Int = 3
35+
/* inlined from Test*/
36+
{
37+
x$1.+(x$1)
38+
}
39+
}

tests/run-with-compiler/i3876.scala

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import dotty.tools.dotc.quoted.Runners._
2+
import scala.quoted._
3+
object Test {
4+
def main(args: Array[String]): Unit = {
5+
val x: Expr[Int] = '(3)
6+
7+
val f: Expr[Int => Int] = '{ (x: Int) => x + x }
8+
println(f(x).run)
9+
println(f(x).show)
10+
11+
val f2: Expr[Int => Int] = '{
12+
def f(x: Int): Int = x + x
13+
f
14+
}
15+
println(f2(x).run)
16+
println(f2(x).show)
17+
18+
val f3: Expr[Int => Int] = '{
19+
val f: (x: Int) => Int = x => x + x
20+
f
21+
}
22+
println(f3(x).run)
23+
println(f3(x).show) // TODO improve printer
24+
25+
val f4: Expr[Int => Int] = '{
26+
inlineLambda
27+
}
28+
println(f4(x).run)
29+
println(f4(x).show)
30+
}
31+
32+
inline def inlineLambda: Int => Int = x => x + x
33+
}

0 commit comments

Comments
 (0)