Skip to content

Commit 97945bc

Browse files
committed
Allow macros to generate method symbols, add missing method type constructors
Since there already exists the machinery to generate a new tree for an existing method symbol, this change simply adds the ability to generate a fresh method symbol with the given name, type and flags. Then, `DefDef.apply` can be used to generate the full tree. Missing type constructors for `PolyType` and `ByNameType` are added, as well as corresponding `.param(Int)` accessors in order to make references to a `MethodType` or a `PolyType`'s parameter types. Testing is achieved by synthesizing a series of method definitions, along with references to these methods and some basic correctness assertions. The purpose of each test is described in a comment above the corresponding block of code. Note 1: there are optional parameters for flags and private in the `Symbol.newMethod` function. These seem unnecessary for just defining local functions, but may be useful in synthesizing more complex declarations that may be supported later, such as local classes. Note 2: implicit/given method types are not exposed by this change. Since these types are structurally identical to normal method types, perhaps some optional flags could be added to the MethodType constructor for this purpose?
1 parent 9982f0d commit 97945bc

File tree

7 files changed

+249
-0
lines changed

7 files changed

+249
-0
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,6 +1372,8 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
13721372
case _ => None
13731373
}
13741374

1375+
def ByNameType_apply(underlying: Type)(given Context): Type = Types.ExprType(underlying)
1376+
13751377
def ByNameType_underlying(self: ByNameType)(given Context): Type = self.resType.stripTypeVar
13761378

13771379
type ParamRef = Types.ParamRef
@@ -1437,6 +1439,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
14371439

14381440
def MethodType_isErased(self: MethodType): Boolean = self.isErasedMethod
14391441
def MethodType_isImplicit(self: MethodType): Boolean = self.isImplicitMethod
1442+
def MethodType_param(self: MethodType, idx: Int)(given Context): Type = self.newParamRef(idx)
14401443
def MethodType_paramNames(self: MethodType)(given Context): List[String] = self.paramNames.map(_.toString)
14411444
def MethodType_paramTypes(self: MethodType)(given Context): List[Type] = self.paramInfos
14421445
def MethodType_resType(self: MethodType)(given Context): Type = self.resType
@@ -1450,6 +1453,10 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
14501453
case _ => None
14511454
}
14521455

1456+
def PolyType_apply(paramNames: List[String])(paramBoundsExp: PolyType => List[TypeBounds], resultTypeExp: PolyType => Type)(given Context): PolyType =
1457+
Types.PolyType(paramNames.map(_.toTypeName))(paramBoundsExp, resultTypeExp)
1458+
1459+
def PolyType_param(self: PolyType, idx: Int)(given Context): Type = self.newParamRef(idx)
14531460
def PolyType_paramNames(self: PolyType)(given Context): List[String] = self.paramNames.map(_.toString)
14541461
def PolyType_paramBounds(self: PolyType)(given Context): List[TypeBounds] = self.paramInfos
14551462
def PolyType_resType(self: PolyType)(given Context): Type = self.resType
@@ -1717,6 +1724,11 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
17171724
def Symbol_of(fullName: String)(given ctx: Context): Symbol =
17181725
ctx.requiredClass(fullName)
17191726

1727+
def Symbol_newMethod(parent: Symbol, name: String, flags: Flags, tpe: Type, privateWithin: Symbol)(given ctx: Context): Symbol = {
1728+
val computedFlags = flags | Flags.Method
1729+
ctx.newSymbol(parent, name.toTermName, computedFlags, tpe, privateWithin)
1730+
}
1731+
17201732
def Symbol_isTypeParam(self: Symbol)(given Context): Boolean =
17211733
self.isTypeParam
17221734

library/src/scala/tasty/reflect/CompilerInterface.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,8 @@ trait CompilerInterface {
986986

987987
def isInstanceOfByNameType(given ctx: Context): IsInstanceOf[ByNameType]
988988

989+
def ByNameType_apply(underlying: Type)(given ctx: Context): Type
990+
989991
def ByNameType_underlying(self: ByNameType)(given ctx: Context): Type
990992

991993
/** Type of a parameter reference */
@@ -1031,6 +1033,7 @@ trait CompilerInterface {
10311033

10321034
def MethodType_isErased(self: MethodType): Boolean
10331035
def MethodType_isImplicit(self: MethodType): Boolean
1036+
def MethodType_param(self: MethodType, ids: Int)(given ctx: Context): Type
10341037
def MethodType_paramNames(self: MethodType)(given ctx: Context): List[String]
10351038
def MethodType_paramTypes(self: MethodType)(given ctx: Context): List[Type]
10361039
def MethodType_resType(self: MethodType)(given ctx: Context): Type
@@ -1040,6 +1043,9 @@ trait CompilerInterface {
10401043

10411044
def isInstanceOfPolyType(given ctx: Context): IsInstanceOf[PolyType]
10421045

1046+
def PolyType_apply(paramNames: List[String])(paramBoundsExp: PolyType => List[TypeBounds], resultTypeExp: PolyType => Type)(given ctx: Context): PolyType
1047+
1048+
def PolyType_param(self: PolyType, idx: Int)(given ctx: Context): Type
10431049
def PolyType_paramNames(self: PolyType)(given ctx: Context): List[String]
10441050
def PolyType_paramBounds(self: PolyType)(given ctx: Context): List[TypeBounds]
10451051
def PolyType_resType(self: PolyType)(given ctx: Context): Type
@@ -1264,6 +1270,8 @@ trait CompilerInterface {
12641270

12651271
def Symbol_of(fullName: String)(given ctx: Context): Symbol
12661272

1273+
def Symbol_newMethod(parent: Symbol, name: String, flags: Flags, tpe: Type, privateWithin: Symbol)(given ctx: Context): Symbol
1274+
12671275
def Symbol_isTypeParam(self: Symbol)(given ctx: Context): Boolean
12681276

12691277
def Symbol_isPackageDef(symbol: Symbol)(given ctx: Context): Boolean

library/src/scala/tasty/reflect/SymbolOps.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@ trait SymbolOps extends Core { selfSymbolOps: FlagsOps =>
99
def classSymbol(fullName: String)(given ctx: Context): Symbol =
1010
internal.Symbol_of(fullName)
1111

12+
/** Generates a new method symbol with the given parent, name and type.
13+
*
14+
* This symbol starts without an accompanying definition.
15+
* It is the meta-programmer's responsibility to provide exactly one corresponding definition by passing
16+
* this symbol to the DefDef constructor.
17+
*
18+
* @note As a macro can only splice code into the point at which it is expanded, all generated symbols must be
19+
* direct or indirect children of the reflection context's owner. */
20+
def newMethod(parent: Symbol, name: String, tpe: Type, flags: Flags = Flags.EmptyFlags, privateWithin: Option[Symbol] = None)(given ctx: Context): Symbol =
21+
internal.Symbol_newMethod(parent, name, flags, tpe, privateWithin.getOrElse(noSymbol))
22+
1223
/** Definition not available */
1324
def noSymbol(given ctx: Context): Symbol =
1425
internal.Symbol_noSymbol

library/src/scala/tasty/reflect/TypeOrBoundsOps.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ trait TypeOrBoundsOps extends Core {
286286
def unapply(x: ByNameType)(given ctx: Context): Option[ByNameType] = Some(x)
287287

288288
object ByNameType {
289+
def apply(underlying: Type)(given ctx: Context): Type = internal.ByNameType_apply(underlying)
289290
def unapply(x: ByNameType)(given ctx: Context): Option[Type] = Some(x.underlying)
290291
}
291292

@@ -368,6 +369,7 @@ trait TypeOrBoundsOps extends Core {
368369
given MethodTypeOps: extension (self: MethodType) {
369370
def isImplicit: Boolean = internal.MethodType_isImplicit(self)
370371
def isErased: Boolean = internal.MethodType_isErased(self)
372+
def param(idx: Int)(given ctx: Context): Type = internal.MethodType_param(self, idx)
371373
def paramNames(given ctx: Context): List[String] = internal.MethodType_paramNames(self)
372374
def paramTypes(given ctx: Context): List[Type] = internal.MethodType_paramTypes(self)
373375
def resType(given ctx: Context): Type = internal.MethodType_resType(self)
@@ -380,11 +382,14 @@ trait TypeOrBoundsOps extends Core {
380382
def unapply(x: PolyType)(given ctx: Context): Option[PolyType] = Some(x)
381383

382384
object PolyType {
385+
def apply(paramNames: List[String])(paramBoundsExp: PolyType => List[TypeBounds], resultTypeExp: PolyType => Type)(given ctx: Context): PolyType =
386+
internal.PolyType_apply(paramNames)(paramBoundsExp, resultTypeExp)
383387
def unapply(x: PolyType)(given ctx: Context): Option[(List[String], List[TypeBounds], Type)] =
384388
Some((x.paramNames, x.paramBounds, x.resType))
385389
}
386390

387391
given PolyTypeOps: extension (self: PolyType) {
392+
def param(idx: Int)(given ctx: Context): Type = internal.PolyType_param(self, idx)
388393
def paramNames(given ctx: Context): List[String] = internal.PolyType_paramNames(self)
389394
def paramBounds(given ctx: Context): List[TypeBounds] = internal.PolyType_paramBounds(self)
390395
def resType(given ctx: Context): Type = internal.PolyType_resType(self)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
sym6_2: 6
2+
sym6_1: 5
3+
sym6_2: 4
4+
sym6_1: 3
5+
sym6_2: 2
6+
sym6_1: 1
7+
sym6_2: 0
8+
Ok
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import quoted._
2+
3+
object Macros {
4+
5+
inline def theTestBlock : Unit = ${ theTestBlockImpl }
6+
7+
def theTestBlockImpl(given qctx: QuoteContext) : Expr[Unit] = {
8+
import qctx.tasty.{_,given}
9+
10+
// simple smoke test
11+
val sym1 : Symbol = Symbol.newMethod(
12+
rootContext.owner,
13+
"sym1",
14+
MethodType(List("a","b"))(
15+
_ => List(typeOf[Int], typeOf[Int]),
16+
_ => typeOf[Int]))
17+
assert(sym1.isDefDef)
18+
assert(sym1.name == "sym1")
19+
val sym1Statements : List[Statement] = List(
20+
DefDef(sym1, {
21+
case List() => {
22+
case List(List(a, b)) =>
23+
Some('{ ${ a.seal.asInstanceOf[Expr[Int]] } - ${ b.seal.asInstanceOf[Expr[Int]] } }.unseal)
24+
}
25+
}),
26+
'{ assert(${ Apply(Ref(sym1), List(Literal(Constant(2)), Literal(Constant(3)))).seal.asInstanceOf[Expr[Int]] } == -1) }.unseal)
27+
28+
// test for no argument list (no Apply node)
29+
val sym2 : Symbol = Symbol.newMethod(
30+
rootContext.owner,
31+
"sym2",
32+
ByNameType(typeOf[Int]))
33+
assert(sym2.isDefDef)
34+
assert(sym2.name == "sym2")
35+
val sym2Statements : List[Statement] = List(
36+
DefDef(sym2, {
37+
case List() => {
38+
case List() =>
39+
Some(Literal(Constant(2)))
40+
}
41+
}),
42+
'{ assert(${ Ref(sym2).seal.asInstanceOf[Expr[Int]] } == 2) }.unseal)
43+
44+
// test for multiple argument lists
45+
val sym3 : Symbol = Symbol.newMethod(
46+
rootContext.owner,
47+
"sym3",
48+
MethodType(List("a"))(
49+
_ => List(typeOf[Int]),
50+
mt => MethodType(List("b"))(
51+
_ => List(mt.param(0)),
52+
_ => mt.param(0))))
53+
assert(sym3.isDefDef)
54+
assert(sym3.name == "sym3")
55+
val sym3Statements : List[Statement] = List(
56+
DefDef(sym3, {
57+
case List() => {
58+
case List(List(a), List(b)) =>
59+
Some(a)
60+
}
61+
}),
62+
'{ assert(${ Apply(Apply(Ref(sym3), List(Literal(Constant(3)))), List(Literal(Constant(3)))).seal.asInstanceOf[Expr[Int]] } == 3) }.unseal)
63+
64+
// test for recursive references
65+
val sym4 : Symbol = Symbol.newMethod(
66+
rootContext.owner,
67+
"sym4",
68+
MethodType(List("x"))(
69+
_ => List(typeOf[Int]),
70+
_ => typeOf[Int]))
71+
assert(sym4.isDefDef)
72+
assert(sym4.name == "sym4")
73+
val sym4Statements : List[Statement] = List(
74+
DefDef(sym4, {
75+
case List() => {
76+
case List(List(x)) =>
77+
Some('{
78+
if ${ x.seal.asInstanceOf[Expr[Int]] } == 0
79+
then 0
80+
else ${ Apply(Ref(sym4), List('{ ${ x.seal.asInstanceOf[Expr[Int]] } - 1 }.unseal)).seal.asInstanceOf[Expr[Int]] }
81+
}.unseal)
82+
}
83+
}),
84+
'{ assert(${ Apply(Ref(sym4), List(Literal(Constant(4)))).seal.asInstanceOf[Expr[Int]] } == 0) }.unseal)
85+
86+
// test for nested functions (one symbol is the other's parent, and we use a Closure)
87+
val sym5 : Symbol = Symbol.newMethod(
88+
rootContext.owner,
89+
"sym5",
90+
MethodType(List("x"))(
91+
_ => List(typeOf[Int]),
92+
_ => typeOf[Int=>Int]))
93+
assert(sym5.isDefDef)
94+
assert(sym5.name == "sym5")
95+
val sym5Statements : List[Statement] = List(
96+
DefDef(sym5, {
97+
case List() => {
98+
case List(List(x)) =>
99+
Some {
100+
val sym51 : Symbol = Symbol.newMethod(
101+
sym5,
102+
"sym51",
103+
MethodType(List("x"))(
104+
_ => List(typeOf[Int]),
105+
_ => typeOf[Int]))
106+
Block(
107+
List(
108+
DefDef(sym51, {
109+
case List() => {
110+
case List(List(xx)) =>
111+
Some('{ ${ x.seal.asInstanceOf[Expr[Int]] } - ${ xx.seal.asInstanceOf[Expr[Int]] } }.unseal)
112+
}
113+
})),
114+
Closure(Ref(sym51), None))
115+
}
116+
}
117+
}),
118+
'{ assert(${ Apply(Ref(sym5), List(Literal(Constant(5)))).seal.asInstanceOf[Expr[Int=>Int]] }(4) == 1) }.unseal)
119+
120+
// test mutually recursive definitions
121+
val sym6_1 : Symbol = Symbol.newMethod(
122+
rootContext.owner,
123+
"sym6_1",
124+
MethodType(List("x"))(
125+
_ => List(typeOf[Int]),
126+
_ => typeOf[Int]))
127+
val sym6_2 : Symbol = Symbol.newMethod(
128+
rootContext.owner,
129+
"sym6_2",
130+
MethodType(List("x"))(
131+
_ => List(typeOf[Int]),
132+
_ => typeOf[Int]))
133+
assert(sym6_1.isDefDef)
134+
assert(sym6_2.isDefDef)
135+
assert(sym6_1.name == "sym6_1")
136+
assert(sym6_2.name == "sym6_2")
137+
val sym6Statements : List[Statement] = List(
138+
DefDef(sym6_1, {
139+
case List() => {
140+
case List(List(x)) =>
141+
Some {
142+
'{
143+
println(s"sym6_1: ${ ${ x.seal.asInstanceOf[Expr[Int]] } }")
144+
if ${ x.seal.asInstanceOf[Expr[Int]] } == 0
145+
then 0
146+
else ${ Apply(Ref(sym6_2), List('{ ${ x.seal.asInstanceOf[Expr[Int]] } - 1 }.unseal)).seal.asInstanceOf[Expr[Int]] }
147+
}.unseal
148+
}
149+
}
150+
}),
151+
DefDef(sym6_2, {
152+
case List() => {
153+
case List(List(x)) =>
154+
Some {
155+
'{
156+
println(s"sym6_2: ${ ${ x.seal.asInstanceOf[Expr[Int]] } }")
157+
if ${ x.seal.asInstanceOf[Expr[Int]] } == 0
158+
then 0
159+
else ${ Apply(Ref(sym6_1), List('{ ${ x.seal.asInstanceOf[Expr[Int]] } - 1 }.unseal)).seal.asInstanceOf[Expr[Int]] }
160+
}.unseal
161+
}
162+
}
163+
164+
}),
165+
'{ assert(${ Apply(Ref(sym6_2), List(Literal(Constant(6)))).seal.asInstanceOf[Expr[Int]] } == 0) }.unseal)
166+
167+
// test polymorphic methods by synthesizing an identity method
168+
val sym7 : Symbol = Symbol.newMethod(
169+
rootContext.owner,
170+
"sym7",
171+
PolyType(List("T"))(
172+
tp => List(TypeBounds(typeOf[Nothing], typeOf[Any])),
173+
tp => MethodType(List("t"))(
174+
_ => List(tp.param(0)),
175+
_ => tp.param(0))))
176+
assert(sym7.isDefDef)
177+
assert(sym7.name == "sym7")
178+
val sym7Statements : List[Statement] = List(
179+
DefDef(sym7, {
180+
case List(t) => {
181+
case List(List(x)) =>
182+
Some(Typed(x, Inferred(t)))
183+
}
184+
}),
185+
'{ assert(${ Apply(TypeApply(Ref(sym7), List(Inferred(typeOf[Int]))), List(Literal(Constant(7)))).seal.asInstanceOf[Expr[Int]] } == 7) }.unseal)
186+
187+
Block(
188+
sym1Statements ++
189+
sym2Statements ++
190+
sym3Statements ++
191+
sym4Statements ++
192+
sym5Statements ++
193+
sym6Statements ++
194+
sym7Statements ++
195+
List('{ println("Ok") }.unseal),
196+
Literal(Constant(()))).seal.asInstanceOf[Expr[Unit]]
197+
}
198+
}
199+
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
object Test {
3+
def main(argv: Array[String]): Unit =
4+
Macros.theTestBlock
5+
}
6+

0 commit comments

Comments
 (0)