Skip to content

Commit bb6a6cb

Browse files
Merge pull request #5200 from dotty-staging/add-inline-case-classes
Add inline case classes as argument
2 parents 786e6da + 5a363f1 commit bb6a6cb

File tree

21 files changed

+789
-39
lines changed

21 files changed

+789
-39
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -111,16 +111,14 @@ object Splicer {
111111
}
112112

113113
protected def interpretStaticMethodCall(fn: Tree, args: => List[Object])(implicit env: Env): Object = {
114-
if (fn.symbol == defn.NoneModuleRef.termSymbol) {
115-
// TODO generalize
116-
None
117-
} else {
118-
val (clazz, instance) = loadModule(fn.symbol.owner)
119-
val method = getMethod(clazz, fn.symbol.name, paramsSig(fn.symbol))
120-
stopIfRuntimeException(method.invoke(instance, args: _*))
121-
}
114+
val instance = loadModule(fn.symbol.owner)
115+
val method = getMethod(instance.getClass, fn.symbol.name, paramsSig(fn.symbol))
116+
stopIfRuntimeException(method.invoke(instance, args: _*))
122117
}
123118

119+
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Object =
120+
loadModule(fn.symbol.moduleClass)
121+
124122
protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Object = {
125123
val clazz = loadClass(fn.symbol.owner.fullName)
126124
val constr = clazz.getConstructor(paramsSig(fn.symbol): _*)
@@ -130,24 +128,23 @@ object Splicer {
130128
protected def unexpectedTree(tree: Tree)(implicit env: Env): Object =
131129
throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.pos)
132130

133-
private def loadModule(sym: Symbol): (Class[_], Object) = {
131+
private def loadModule(sym: Symbol): Object = {
134132
if (sym.owner.is(Package)) {
135133
// is top level object
136134
val moduleClass = loadClass(sym.fullName)
137-
val moduleInstance = moduleClass.getField(MODULE_INSTANCE_FIELD).get(null)
138-
(moduleClass, moduleInstance)
135+
moduleClass.getField(MODULE_INSTANCE_FIELD).get(null)
139136
} else {
140137
// nested object in an object
141138
val clazz = loadClass(sym.fullNameSeparated(FlatName))
142-
(clazz, clazz.getConstructor().newInstance().asInstanceOf[Object])
139+
clazz.getConstructor().newInstance().asInstanceOf[Object]
143140
}
144141
}
145142

146143
private def loadClass(name: Name): Class[_] = {
147144
try classLoader.loadClass(name.toString)
148145
catch {
149146
case _: ClassNotFoundException =>
150-
val msg = s"Could not find macro class $name in classpath$extraMsg"
147+
val msg = s"Could not find class $name in classpath$extraMsg"
151148
throw new StopInterpretation(msg, pos)
152149
}
153150
}
@@ -156,7 +153,7 @@ object Splicer {
156153
try clazz.getMethod(name.toString, paramClasses: _*)
157154
catch {
158155
case _: NoSuchMethodException =>
159-
val msg = em"Could not find macro method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)$extraMsg"
156+
val msg = em"Could not find method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)$extraMsg"
160157
throw new StopInterpretation(msg, pos)
161158
}
162159
}
@@ -257,7 +254,8 @@ object Splicer {
257254
protected def interpretVarargs(args: List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
258255
protected def interpretTastyContext()(implicit env: Env): Boolean = true
259256
protected def interpretStaticMethodCall(fn: tpd.Tree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
260-
protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Boolean = args.forall(identity)
257+
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Boolean = true
258+
protected def interpretNew(fn: RefTree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
261259

262260
def unexpectedTree(tree: tpd.Tree)(implicit env: Env): Boolean = {
263261
// Assuming that top-level splices can only be in inline methods
@@ -278,6 +276,7 @@ object Splicer {
278276
protected def interpretVarargs(args: List[Result])(implicit env: Env): Result
279277
protected def interpretTastyContext()(implicit env: Env): Result
280278
protected def interpretStaticMethodCall(fn: Tree, args: => List[Result])(implicit env: Env): Result
279+
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Result
281280
protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Result
282281
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result
283282

@@ -294,8 +293,17 @@ object Splicer {
294293
case _ if tree.symbol == defn.TastyTasty_macroContext =>
295294
interpretTastyContext()
296295

297-
case StaticMethodCall(fn, args) =>
298-
interpretStaticMethodCall(fn, args.map(arg => interpretTree(arg)))
296+
case Call(fn, args) =>
297+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
298+
interpretNew(fn, args.map(interpretTree))
299+
} else if (fn.symbol.isStatic) {
300+
if (fn.symbol.is(Module)) interpretModuleAccess(fn)
301+
else interpretStaticMethodCall(fn, args.map(arg => interpretTree(arg)))
302+
} else if (env.contains(fn.name)) {
303+
env(fn.name)
304+
} else {
305+
unexpectedTree(tree)
306+
}
299307

300308
// Interpret `foo(j = x, i = y)` which it is expanded to
301309
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
@@ -307,28 +315,21 @@ object Splicer {
307315
})
308316
interpretTree(expr)(newEnv)
309317
case NamedArg(_, arg) => interpretTree(arg)
310-
case Ident(name) if env.contains(name) => env(name)
311318

312319
case Inlined(EmptyTree, Nil, expansion) => interpretTree(expansion)
313320

314-
case Apply(TypeApply(fun: RefTree, _), args) if fun.symbol.isConstructor && fun.symbol.owner.owner.is(Package) =>
315-
interpretNew(fun, args.map(interpretTree))
316-
317-
case Apply(fun: RefTree, args) if fun.symbol.isConstructor && fun.symbol.owner.owner.is(Package)=>
318-
interpretNew(fun, args.map(interpretTree))
319-
320321
case Typed(SeqLiteral(elems, _), _) =>
321322
interpretVarargs(elems.map(e => interpretTree(e)))
322323

323324
case _ =>
324325
unexpectedTree(tree)
325326
}
326327

327-
object StaticMethodCall {
328+
object Call {
328329
def unapply(arg: Tree): Option[(RefTree, List[Tree])] = arg match {
329-
case fn: RefTree if fn.symbol.isStatic => Some((fn, Nil))
330-
case Apply(StaticMethodCall(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance
331-
case TypeApply(StaticMethodCall(fn, args), _) => Some((fn, args))
330+
case fn: RefTree => Some((fn, Nil))
331+
case Apply(Call(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance
332+
case TypeApply(Call(fn, args), _) => Some((fn, args))
332333
case _ => None
333334
}
334335
}

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -679,20 +679,28 @@ trait Checking {
679679
tree.tpe.widenTermRefExpr match {
680680
case tp: ConstantType if exprPurity(tree) >= purityLevel => // ok
681681
case tp =>
682+
def isCaseClassApply(sym: Symbol): Boolean =
683+
sym.name == nme.apply && sym.is(Synthetic) && sym.owner.is(Module) && sym.owner.companionClass.is(Case)
684+
def isCaseClassNew(sym: Symbol): Boolean =
685+
sym.isPrimaryConstructor && sym.owner.is(Case) && sym.owner.isStatic
686+
def isCaseObject(sym: Symbol): Boolean = {
687+
// TODO add alias to Nil in scala package
688+
sym.is(Case) && sym.is(Module)
689+
}
682690
val allow =
683691
ctx.erasedTypes ||
684692
ctx.inInlineMethod ||
685-
// TODO: Make None and Some constant types?
686-
tree.symbol.eq(defn.NoneModuleRef.termSymbol) ||
687-
tree.symbol.eq(defn.SomeClass.primaryConstructor) ||
688-
(tree.symbol.name == nme.apply && tree.symbol.owner == defn.SomeClass.companionModule.moduleClass)
689-
if (!allow) ctx.error(em"$what must be a constant expression", tree.pos)
690-
else tree match {
691-
// TODO: add cases for type apply and multiple applies
692-
case Apply(_, args) =>
693-
for (arg <- args)
694-
checkInlineConformant(arg, isFinal, what)
695-
case _ =>
693+
(tree.symbol.isStatic && isCaseObject(tree.symbol) || isCaseClassApply(tree.symbol)) ||
694+
isCaseClassNew(tree.symbol)
695+
if (!allow) ctx.error(em"$what must be a known value", tree.pos)
696+
else {
697+
def checkArgs(tree: Tree): Unit = tree match {
698+
case Apply(fn, args) =>
699+
args.foreach(arg => checkInlineConformant(arg, isFinal, what))
700+
checkArgs(fn)
701+
case _ =>
702+
}
703+
checkArgs(tree)
696704
}
697705
}
698706
}

compiler/test/dotc/run-test-pickling.blacklist

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ i5119
2424
i5119b
2525
inline-varargs-1
2626
implicitShortcut
27+
inline-case-objects
2728
inline-option
29+
inline-macro-staged-interpreter
30+
inline-tuples-1
31+
inline-tuples-2
2832
lazy-implicit-lists.scala
2933
lazy-implicit-nums.scala
3034
lazy-traits.scala
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
2+
import scala.quoted._
3+
4+
object Macros {
5+
def impl(foo: Any): Expr[String] = foo.getClass.getCanonicalName.toExpr
6+
}
7+
8+
class Bar {
9+
case object Baz
10+
}
11+
12+
package foo {
13+
class Bar {
14+
case object Baz
15+
}
16+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
2+
object Test {
3+
4+
def main(args: Array[String]): Unit = {
5+
val bar = new Bar
6+
println(fooString(bar.Baz)) // error
7+
}
8+
9+
inline def fooString(inline x: Any): String = ~Macros.impl(x)
10+
11+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
2+
import scala.quoted._
3+
4+
object E {
5+
6+
inline def eval[T](inline x: E[T]): T = ~impl(x)
7+
8+
def impl[T](x: E[T]): Expr[T] = x.lift
9+
10+
}
11+
12+
trait E[T] {
13+
def lift: Expr[T]
14+
}
15+
16+
case class I(n: Int) extends E[Int] {
17+
def lift: Expr[Int] = n.toExpr
18+
}
19+
20+
case class Plus[T](x: E[T], y: E[T])(implicit op: Plus2[T]) extends E[T] {
21+
def lift: Expr[T] = op(x.lift, y.lift)
22+
}
23+
24+
trait Op2[T] {
25+
def apply(x: Expr[T], y: Expr[T]): Expr[T]
26+
}
27+
28+
trait Plus2[T] extends Op2[T]
29+
object Plus2 {
30+
implicit case object IPlus extends Plus2[Int] {
31+
def apply(x: Expr[Int], y: Expr[Int]): Expr[Int] = '(~x + ~y)
32+
}
33+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
2+
object Test {
3+
4+
def main(args: Array[String]): Unit = {
5+
val i = I(2)
6+
E.eval(
7+
i // error
8+
)
9+
10+
E.eval(Plus(
11+
i, // error
12+
I(4)))
13+
14+
val plus = Plus2.IPlus
15+
E.eval(Plus(I(2), I(4))(
16+
plus // error
17+
))
18+
}
19+
20+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
import scala.quoted._
3+
4+
object Macros {
5+
def tup1(tup: Tuple1[Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
6+
def tup2(tup: Tuple2[Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
7+
def tup3(tup: Tuple3[Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
8+
def tup4(tup: Tuple4[Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
9+
def tup5(tup: Tuple5[Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
10+
def tup6(tup: Tuple6[Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
11+
def tup7(tup: Tuple7[Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
12+
def tup8(tup: Tuple8[Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
13+
def tup9(tup: Tuple9[Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
14+
def tup10(tup: Tuple10[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
15+
def tup11(tup: Tuple11[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
16+
def tup12(tup: Tuple12[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
17+
def tup13(tup: Tuple13[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
18+
def tup14(tup: Tuple14[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
19+
def tup15(tup: Tuple15[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
20+
def tup16(tup: Tuple16[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
21+
def tup17(tup: Tuple17[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
22+
def tup18(tup: Tuple18[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
23+
def tup19(tup: Tuple19[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
24+
def tup20(tup: Tuple20[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
25+
def tup21(tup: Tuple21[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
26+
def tup22(tup: Tuple22[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
27+
}

0 commit comments

Comments
 (0)