Skip to content

Commit 0f9a2b9

Browse files
committed
Allow case objects to be inline arguments
1 parent 3cfb4e9 commit 0f9a2b9

File tree

8 files changed

+89
-18
lines changed

8 files changed

+89
-18
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,13 @@ object Splicer {
111111
}
112112

113113
protected def interpretStaticMethodCall(fn: Tree, args: => List[Object])(implicit env: Env): Object = {
114-
if (fn.symbol == defn.NoneModuleRef.termSymbol) {
115-
// TODO generalize
116-
None
117-
} else {
118-
val (clazz, instance) = loadModule(fn.symbol.owner)
119-
val method = getMethod(clazz, fn.symbol.name, paramsSig(fn.symbol))
120-
stopIfRuntimeException(method.invoke(instance, args: _*))
121-
}
114+
val (clazz, instance) = loadModule(fn.symbol.owner)
115+
val method = getMethod(clazz, fn.symbol.name, paramsSig(fn.symbol))
116+
stopIfRuntimeException(method.invoke(instance, args: _*))
117+
}
118+
119+
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Object = {
120+
loadModule(fn.symbol.moduleClass)._2
122121
}
123122

124123
protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Object = {
@@ -147,7 +146,7 @@ object Splicer {
147146
try classLoader.loadClass(name.toString)
148147
catch {
149148
case _: ClassNotFoundException =>
150-
val msg = s"Could not find macro class $name in classpath$extraMsg"
149+
val msg = s"Could not find class $name in classpath$extraMsg"
151150
throw new StopInterpretation(msg, pos)
152151
}
153152
}
@@ -156,7 +155,7 @@ object Splicer {
156155
try clazz.getMethod(name.toString, paramClasses: _*)
157156
catch {
158157
case _: NoSuchMethodException =>
159-
val msg = em"Could not find macro method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)$extraMsg"
158+
val msg = em"Could not find method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)$extraMsg"
160159
throw new StopInterpretation(msg, pos)
161160
}
162161
}
@@ -257,7 +256,8 @@ object Splicer {
257256
protected def interpretVarargs(args: List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
258257
protected def interpretTastyContext()(implicit env: Env): Boolean = true
259258
protected def interpretStaticMethodCall(fn: tpd.Tree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
260-
protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Boolean = args.forall(identity)
259+
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Boolean = true
260+
protected def interpretNew(fn: RefTree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
261261

262262
def unexpectedTree(tree: tpd.Tree)(implicit env: Env): Boolean = {
263263
// Assuming that top-level splices can only be in inline methods
@@ -278,6 +278,7 @@ object Splicer {
278278
protected def interpretVarargs(args: List[Result])(implicit env: Env): Result
279279
protected def interpretTastyContext()(implicit env: Env): Result
280280
protected def interpretStaticMethodCall(fn: Tree, args: => List[Result])(implicit env: Env): Result
281+
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Result
281282
protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Result
282283
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result
283284

@@ -294,8 +295,13 @@ object Splicer {
294295
case _ if tree.symbol == defn.TastyTasty_macroContext =>
295296
interpretTastyContext()
296297

297-
case StaticMethodCall(fn, args) =>
298-
interpretStaticMethodCall(fn, args.map(arg => interpretTree(arg)))
298+
case StaticCall(fn, args) =>
299+
if (fn.symbol.is(Module)) {
300+
assert(args.isEmpty)
301+
interpretModuleAccess(fn)
302+
} else {
303+
interpretStaticMethodCall(fn, args.map(arg => interpretTree(arg)))
304+
}
299305

300306
// Interpret `foo(j = x, i = y)` which it is expanded to
301307
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
@@ -324,11 +330,11 @@ object Splicer {
324330
unexpectedTree(tree)
325331
}
326332

327-
object StaticMethodCall {
333+
object StaticCall {
328334
def unapply(arg: Tree): Option[(RefTree, List[Tree])] = arg match {
329335
case fn: RefTree if fn.symbol.isStatic => Some((fn, Nil))
330-
case Apply(StaticMethodCall(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance
331-
case TypeApply(StaticMethodCall(fn, args), _) => Some((fn, args))
336+
case Apply(StaticCall(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance
337+
case TypeApply(StaticCall(fn, args), _) => Some((fn, args))
332338
case _ => None
333339
}
334340
}

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,6 @@ trait Checking {
679679
tree.tpe.widenTermRefExpr match {
680680
case tp: ConstantType if exprPurity(tree) >= purityLevel => // ok
681681
case tp =>
682-
// TODO: Make None and Some constant types?
683682
def isCaseClassApply(sym: Symbol): Boolean = {
684683
sym.name == nme.apply && (
685684
tree.symbol.owner == defn.SomeClass.companionModule.moduleClass ||
@@ -693,7 +692,8 @@ trait Checking {
693692
)
694693
}
695694
def isCaseObject(sym: Symbol): Boolean = {
696-
tree.symbol.eq(defn.NoneModuleRef.termSymbol)
695+
// TODO add alias to Nil in scala package
696+
sym.is(Case) && sym.is(Module) && sym.isStatic
697697
}
698698
val allow =
699699
ctx.erasedTypes ||

compiler/test/dotc/run-test-pickling.blacklist

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ i5119
2424
i5119b
2525
inline-varargs-1
2626
implicitShortcut
27+
inline-case-objects
2728
inline-option
2829
inline-tuples-1
2930
inline-tuples-2
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
2+
import scala.quoted._
3+
4+
object Macros {
5+
def impl(foo: Any): Expr[String] = foo.getClass.getCanonicalName.toExpr
6+
}
7+
8+
class Bar {
9+
case object Baz
10+
}
11+
12+
package foo {
13+
class Bar {
14+
case object Baz
15+
}
16+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
2+
object Test {
3+
4+
def main(args: Array[String]): Unit = {
5+
val bar = new Bar
6+
println(fooString(bar.Baz)) // error
7+
}
8+
9+
inline def fooString(inline x: Any): String = ~Macros.impl(x)
10+
11+
}

tests/run/inline-case-objects.check

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
scala.collection.immutable.Nil$
2+
scala.None$
3+
Bar$
4+
Bar.Baz$
5+
foo.Bar$
6+
Bar.Baz$
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
2+
import scala.quoted._
3+
4+
object Macros {
5+
def impl(foo: Any): Expr[String] = foo.getClass.getCanonicalName.toExpr
6+
}
7+
8+
case object Bar {
9+
case object Baz
10+
}
11+
12+
package foo {
13+
case object Bar {
14+
case object Baz
15+
}
16+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
object Test {
3+
4+
def main(args: Array[String]): Unit = {
5+
println(fooString(scala.collection.immutable.Nil))
6+
println(fooString(None))
7+
println(fooString(Bar))
8+
println(fooString(Bar.Baz))
9+
println(fooString(foo.Bar))
10+
println(fooString(foo.Bar.Baz))
11+
}
12+
13+
inline def fooString(inline x: Any): String = ~Macros.impl(x)
14+
15+
}

0 commit comments

Comments
 (0)