Skip to content

Commit 9e8fa84

Browse files
committed
Allow case classes to be inline arguments
1 parent fa10045 commit 9e8fa84

File tree

8 files changed

+165
-39
lines changed

8 files changed

+165
-39
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -295,12 +295,16 @@ object Splicer {
295295
case _ if tree.symbol == defn.TastyTasty_macroContext =>
296296
interpretTastyContext()
297297

298-
case StaticCall(fn, args) =>
299-
if (fn.symbol.is(Module)) {
300-
assert(args.isEmpty)
301-
interpretModuleAccess(fn)
298+
case Call(fn, args) =>
299+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
300+
interpretNew(fn, args.map(interpretTree))
301+
} else if (fn.symbol.isStatic) {
302+
if (fn.symbol.is(Module)) interpretModuleAccess(fn)
303+
else interpretStaticMethodCall(fn, args.map(arg => interpretTree(arg)))
304+
} else if (env.contains(fn.name)) {
305+
env(fn.name)
302306
} else {
303-
interpretStaticMethodCall(fn, args.map(arg => interpretTree(arg)))
307+
unexpectedTree(tree)
304308
}
305309

306310
// Interpret `foo(j = x, i = y)` which it is expanded to
@@ -313,28 +317,21 @@ object Splicer {
313317
})
314318
interpretTree(expr)(newEnv)
315319
case NamedArg(_, arg) => interpretTree(arg)
316-
case Ident(name) if env.contains(name) => env(name)
317320

318321
case Inlined(EmptyTree, Nil, expansion) => interpretTree(expansion)
319322

320-
case Apply(TypeApply(fun: RefTree, _), args) if fun.symbol.isConstructor && fun.symbol.owner.owner.is(Package) =>
321-
interpretNew(fun, args.map(interpretTree))
322-
323-
case Apply(fun: RefTree, args) if fun.symbol.isConstructor && fun.symbol.owner.owner.is(Package)=>
324-
interpretNew(fun, args.map(interpretTree))
325-
326323
case Typed(SeqLiteral(elems, _), _) =>
327324
interpretVarargs(elems.map(e => interpretTree(e)))
328325

329326
case _ =>
330327
unexpectedTree(tree)
331328
}
332329

333-
object StaticCall {
330+
object Call {
334331
def unapply(arg: Tree): Option[(RefTree, List[Tree])] = arg match {
335-
case fn: RefTree if fn.symbol.isStatic => Some((fn, Nil))
336-
case Apply(StaticCall(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance
337-
case TypeApply(StaticCall(fn, args), _) => Some((fn, args))
332+
case fn: RefTree => Some((fn, Nil))
333+
case Apply(Call(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance
334+
case TypeApply(Call(fn, args), _) => Some((fn, args))
338335
case _ => None
339336
}
340337
}

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -679,35 +679,28 @@ trait Checking {
679679
tree.tpe.widenTermRefExpr match {
680680
case tp: ConstantType if exprPurity(tree) >= purityLevel => // ok
681681
case tp =>
682-
def isCaseClassApply(sym: Symbol): Boolean = {
683-
sym.name == nme.apply && (
684-
tree.symbol.owner == defn.SomeClass.companionModule.moduleClass ||
685-
defn.isTupleClass(tree.symbol.owner.companionClass)
686-
)
687-
}
688-
def isCaseClassNew(sym: Symbol): Boolean = {
689-
sym.isPrimaryConstructor && (
690-
sym.eq(defn.SomeClass.primaryConstructor) ||
691-
defn.isTupleClass(tree.symbol.owner)
692-
)
693-
}
682+
def isCaseClassApply(sym: Symbol): Boolean =
683+
sym.name == nme.apply && sym.owner.is(Module) && sym.owner.companionClass.is(Case)
684+
def isCaseClassNew(sym: Symbol): Boolean =
685+
sym.isPrimaryConstructor && sym.owner.is(Case) && sym.owner.isStatic
694686
def isCaseObject(sym: Symbol): Boolean = {
695687
// TODO add alias to Nil in scala package
696-
sym.is(Case) && sym.is(Module) && sym.isStatic
688+
sym.is(Case) && sym.is(Module)
697689
}
698690
val allow =
699691
ctx.erasedTypes ||
700692
ctx.inInlineMethod ||
701-
isCaseClassApply(tree.symbol) ||
702-
isCaseClassNew(tree.symbol) ||
703-
isCaseObject(tree.symbol)
704-
if (!allow) ctx.error(em"$what must be a constant expression", tree.pos)
705-
else tree match {
706-
// TODO: add cases for type apply and multiple applies
707-
case Apply(_, args) =>
708-
for (arg <- args)
709-
checkInlineConformant(arg, isFinal, what)
710-
case _ =>
693+
(tree.symbol.isStatic && isCaseObject(tree.symbol) || isCaseClassApply(tree.symbol)) ||
694+
(tree.symbol.owner.isStatic && isCaseClassNew(tree.symbol))
695+
if (!allow) ctx.error(em"$what must be a known value", tree.pos)
696+
else {
697+
def checkArgs(tree: Tree): Unit = tree match {
698+
case Apply(fn, args) =>
699+
args.foreach(arg => checkInlineConformant(arg, isFinal, what))
700+
checkArgs(fn)
701+
case _ =>
702+
}
703+
checkArgs(tree)
711704
}
712705
}
713706
}

compiler/test/dotc/run-test-pickling.blacklist

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ inline-varargs-1
2626
implicitShortcut
2727
inline-case-objects
2828
inline-option
29+
inline-macro-staged-interpreter
2930
inline-tuples-1
3031
inline-tuples-2
3132
lazy-implicit-lists.scala
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
2+
import scala.quoted._
3+
4+
object E {
5+
6+
inline def eval[T](inline x: E[T]): T = ~impl(x)
7+
8+
def impl[T](x: E[T]): Expr[T] = x.lift
9+
10+
}
11+
12+
trait E[T] {
13+
def lift: Expr[T]
14+
}
15+
16+
case class I(n: Int) extends E[Int] {
17+
def lift: Expr[Int] = n.toExpr
18+
}
19+
20+
case class Plus[T](x: E[T], y: E[T])(implicit op: Plus2[T]) extends E[T] {
21+
def lift: Expr[T] = op(x.lift, y.lift)
22+
}
23+
24+
trait Op2[T] {
25+
def apply(x: Expr[T], y: Expr[T]): Expr[T]
26+
}
27+
28+
trait Plus2[T] extends Op2[T]
29+
object Plus2 {
30+
implicit case object IPlus extends Plus2[Int] {
31+
def apply(x: Expr[Int], y: Expr[Int]): Expr[Int] = '(~x + ~y)
32+
}
33+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
2+
object Test {
3+
4+
def main(args: Array[String]): Unit = {
5+
val i = I(2)
6+
E.eval(
7+
i // error
8+
)
9+
10+
E.eval(Plus(
11+
i, // error
12+
I(4)))
13+
14+
val plus = Plus2.IPlus
15+
E.eval(Plus(I(2), I(4))(
16+
plus // error
17+
))
18+
}
19+
20+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
2
2+
3
3+
6
4+
7
5+
8
6+
14
7+
3.1
8+
20.4
9+
20.8
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
2+
import scala.quoted._
3+
4+
object E {
5+
6+
inline def eval[T](inline x: E[T]): T = ~impl(x)
7+
8+
def impl[T](x: E[T]): Expr[T] = x.lift
9+
10+
}
11+
12+
trait E[T] {
13+
def lift: Expr[T]
14+
}
15+
16+
case class I(n: Int) extends E[Int] {
17+
def lift: Expr[Int] = n.toExpr
18+
}
19+
20+
case class D(n: Double) extends E[Double] {
21+
def lift: Expr[Double] = n.toExpr
22+
}
23+
24+
case class Plus[T](x: E[T], y: E[T])(implicit op: Plus2[T]) extends E[T] {
25+
def lift: Expr[T] = op(x.lift, y.lift)
26+
}
27+
28+
case class Times[T](x: E[T], y: E[T])(implicit op: Times2[T]) extends E[T] {
29+
def lift: Expr[T] = op(x.lift, y.lift)
30+
}
31+
32+
trait Op2[T] {
33+
def apply(x: Expr[T], y: Expr[T]): Expr[T]
34+
}
35+
36+
trait Plus2[T] extends Op2[T]
37+
object Plus2 {
38+
implicit case object IPlus extends Plus2[Int] {
39+
def apply(x: Expr[Int], y: Expr[Int]): Expr[Int] = '(~x + ~y)
40+
}
41+
42+
implicit case object DPlus extends Plus2[Double] {
43+
def apply(x: Expr[Double], y: Expr[Double]): Expr[Double] = '(~x + ~y)
44+
}
45+
}
46+
47+
trait Times2[T] extends Op2[T]
48+
object Times2 {
49+
implicit case object ITimes extends Times2[Int] {
50+
def apply(x: Expr[Int], y: Expr[Int]): Expr[Int] = '(~x * ~y)
51+
}
52+
53+
implicit case object DTimes extends Times2[Double] {
54+
def apply(x: Expr[Double], y: Expr[Double]): Expr[Double] = '(~x * ~y)
55+
}
56+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
object Test {
3+
4+
def main(args: Array[String]): Unit = {
5+
println(E.eval(I(2)))
6+
println(E.eval(new I(3)))
7+
println(E.eval(Plus(I(2), I(4))))
8+
println(E.eval(new Plus(I(3), I(4))))
9+
println(E.eval(Times(I(2), I(4))))
10+
println(E.eval(Times(I(2), Plus(I(3), I(4)))))
11+
12+
println(E.eval(D(3.1)))
13+
println(E.eval(Times(D(2.4), Plus(D(3.9), D(4.6)))))
14+
println(E.eval(new Times(D(2.6), Plus(D(3.9), new D(4.1)))))
15+
}
16+
17+
}

0 commit comments

Comments
 (0)