Skip to content

Commit 46b815d

Browse files
committed
Allow sealing method references as function types
1 parent 0d4b138 commit 46b815d

File tree

5 files changed

+134
-7
lines changed

5 files changed

+134
-7
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/QuotedOpsImpl.scala

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
package dotty.tools.dotc.tastyreflect
22

3+
import dotty.tools.dotc.ast.tpd
4+
import dotty.tools.dotc.ast.Trees
5+
import dotty.tools.dotc.core.Flags._
6+
import dotty.tools.dotc.core.Symbols.defn
7+
import dotty.tools.dotc.core.StdNames.nme
38
import dotty.tools.dotc.core.quoted.PickledQuotes
9+
import dotty.tools.dotc.core.Types.MethodType
410

511
trait QuotedOpsImpl extends scala.tasty.reflect.QuotedOps with CoreImpl {
612

@@ -15,16 +21,28 @@ trait QuotedOpsImpl extends scala.tasty.reflect.QuotedOps with CoreImpl {
1521
def TermToQuoteDeco(term: Term): TermToQuotedAPI = new TermToQuotedAPI {
1622

1723
def seal[T: scala.quoted.Type](implicit ctx: Context): scala.quoted.Expr[T] = {
18-
typecheck()
19-
new scala.quoted.Exprs.TastyTreeExpr(term).asInstanceOf[scala.quoted.Expr[T]]
20-
}
2124

22-
private def typecheck[T: scala.quoted.Type]()(implicit ctx: Context): Unit = {
23-
val tpt = QuotedTypeDeco(implicitly[scala.quoted.Type[T]]).unseal
24-
if (!(term.tpe <:< tpt.tpe)) {
25+
val expectedType = QuotedTypeDeco(implicitly[scala.quoted.Type[T]]).unseal.tpe
26+
27+
def etaExpand(term: Term): Term = term.tpe.widen match {
28+
case mtpe: MethodType =>
29+
val closureResType = mtpe.resType match {
30+
case t: MethodType => t.toFunctionType()
31+
case t => t
32+
}
33+
val closureTpe = MethodType(mtpe.paramNames, mtpe.paramInfos, closureResType)
34+
val closureMethod = ctx.newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, closureTpe)
35+
tpd.Closure(closureMethod, tss => etaExpand(new tpd.TreeOps(term).appliedToArgs(tss.head)))
36+
case _ => term
37+
}
38+
39+
val expanded = etaExpand(term)
40+
if (expanded.tpe <:< expectedType) {
41+
new scala.quoted.Exprs.TastyTreeExpr(expanded).asInstanceOf[scala.quoted.Expr[T]]
42+
} else {
2543
throw new scala.tasty.TastyTypecheckError(
2644
s"""Term: ${term.show}
27-
|did not conform to type: ${tpt.tpe.show}
45+
|did not conform to type: ${expectedType.show}
2846
|""".stripMargin
2947
)
3048
}

compiler/src/dotty/tools/dotc/tastyreflect/TypeOrBoundsOpsImpl.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@ trait TypeOrBoundsOpsImpl extends scala.tasty.reflect.TypeOrBoundsOps with CoreI
77
def TypeDeco(tpe: Type): TypeAPI = new TypeAPI {
88
def =:=(other: Type)(implicit ctx: Context): Boolean = tpe =:= other
99
def <:<(other: Type)(implicit ctx: Context): Boolean = tpe <:< other
10+
11+
/** Widen from singleton type to its underlying non-singleton
12+
* base type by applying one or more `underlying` dereferences,
13+
* Also go from => T to T.
14+
* Identity for all other types. Example:
15+
*
16+
* class Outer { class C ; val x: C }
17+
* def o: Outer
18+
* <o.x.type>.widen = o.C
19+
*/
20+
def widen(implicit ctx: Context): Type = tpe.widen
21+
1022
}
1123

1224
def ConstantTypeDeco(x: ConstantType): Type.ConstantTypeAPI = new Type.ConstantTypeAPI {

library/src/scala/tasty/reflect/TypeOrBoundsOps.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ trait TypeOrBoundsOps extends Core {
5252
trait TypeAPI {
5353
def =:=(other: Type)(implicit ctx: Context): Boolean
5454
def <:<(other: Type)(implicit ctx: Context): Boolean
55+
def widen(implicit ctx: Context): Type
5556
}
5657

5758
val IsType: IsTypeModule
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import scala.quoted._
2+
3+
import scala.tasty._
4+
5+
object Asserts {
6+
7+
// TODO remove
8+
implicit val toolbox: scala.quoted.Toolbox = scala.quoted.Toolbox.make
9+
10+
inline def zeroLastArgs(x: => Int): Int =
11+
~zeroLastArgsImpl('(x))
12+
13+
def zeroLastArgsImpl(x: Expr[Int])(implicit reflect: Reflection): Expr[Int] = {
14+
import reflect._
15+
x.unseal.underlyingArgument match {
16+
case Term.Apply(fn, args) =>
17+
fn.tpe.widen match {
18+
case Type.IsMethodType(_) =>
19+
args.size match {
20+
case 1 => fn.seal[Int => Int].apply('(0))
21+
case 2 => fn.seal[(Int, Int) => Int].apply('(0), '(0))
22+
case 3 => fn.seal[(Int, Int, Int) => Int].apply('(0), '(0), '(0))
23+
}
24+
}
25+
case _ => x
26+
}
27+
}
28+
29+
inline def zeroAllArgs(x: => Int): Int =
30+
~zeroAllArgsImpl('(x))
31+
32+
def zeroAllArgsImpl(x: Expr[Int])(implicit reflect: Reflection): Expr[Int] = {
33+
import reflect._
34+
35+
x.unseal.underlyingArgument match {
36+
case Term.Apply(Term.Apply(fn, args1), args2) =>
37+
fn.tpe.widen match {
38+
case Type.IsMethodType(_) =>
39+
args1.size match {
40+
case 1 =>
41+
// TODO fn.seal[Int => Any].apply('(0)).unseal... and then apply next part
42+
args2.size match {
43+
case 1 => fn.seal[Int => Int => Int].apply('(0)).apply('(0))
44+
case 2 => fn.seal[Int => (Int, Int) => Int].apply('(0)).apply('(0), '(0))
45+
case 3 => fn.seal[Int => (Int, Int, Int) => Int].apply('(0)).apply('(0), '(0), '(0))
46+
}
47+
case 2 =>
48+
args2.size match {
49+
case 1 => fn.seal[(Int, Int) => Int => Int].apply('(0), '(0)).apply('(0))
50+
case 2 => fn.seal[(Int, Int) => (Int, Int) => Int].apply('(0), '(0)).apply('(0), '(0))
51+
case 3 => fn.seal[(Int, Int) => (Int, Int, Int) => Int].apply('(0), '(0)).apply('(0), '(0), '(0))
52+
}
53+
case 3 =>
54+
args2.size match {
55+
case 1 => fn.seal[(Int, Int, Int) => Int => Int].apply('(0), '(0), '(0)).apply('(0))
56+
case 2 => fn.seal[(Int, Int, Int) => (Int, Int) => Int].apply('(0), '(0), '(0)).apply('(0), '(0))
57+
case 3 => fn.seal[(Int, Int, Int) => (Int, Int, Int) => Int].apply('(0), '(0), '(0)).apply('(0), '(0), '(0))
58+
}
59+
}
60+
}
61+
case Term.Apply(fn, args) =>
62+
fn.tpe.widen match {
63+
case Type.IsMethodType(_) =>
64+
args.size match {
65+
case 1 => fn.seal[Int => Int].apply('(0))
66+
case 2 => fn.seal[(Int, Int) => Int].apply('(0), '(0))
67+
case 3 => fn.seal[(Int, Int, Int) => Int].apply('(0), '(0), '(0))
68+
}
69+
}
70+
case _ => x
71+
}
72+
}
73+
74+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
2+
import Asserts._
3+
4+
object Test {
5+
def main(args: Array[String]): Unit = {
6+
assert(zeroLastArgs(-1) == -1)
7+
assert(zeroLastArgs(f1(2)) == 1)
8+
assert(zeroLastArgs(f2(2, 3)) == 2)
9+
assert(zeroLastArgs(f3(2)(4, 5)) == 5)
10+
11+
assert(zeroAllArgs(-1) == -1)
12+
assert(zeroAllArgs(f1(2)) == 1)
13+
assert(zeroAllArgs(f2(2, 3)) == 2)
14+
assert(zeroAllArgs(f3(2)(4, 5)) == 3)
15+
}
16+
17+
def f1(i: Int): Int = 1 + i
18+
def f2(i: Int, j: Int): Int = 2 + i + j
19+
def f3(i: Int)(j: Int, k: Int): Int = 3 + i + j
20+
def f4(i: Int, j: Int)(k: Int, l: Int): Int = 4 + i + j + k + l
21+
22+
}

0 commit comments

Comments
 (0)