Skip to content

Commit 90b2a5c

Browse files
committed
Allow case classes to be inline arguments
1 parent 0f9a2b9 commit 90b2a5c

File tree

7 files changed

+141
-13
lines changed

7 files changed

+141
-13
lines changed

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -679,12 +679,8 @@ trait Checking {
679679
tree.tpe.widenTermRefExpr match {
680680
case tp: ConstantType if exprPurity(tree) >= purityLevel => // ok
681681
case tp =>
682-
def isCaseClassApply(sym: Symbol): Boolean = {
683-
sym.name == nme.apply && (
684-
tree.symbol.owner == defn.SomeClass.companionModule.moduleClass ||
685-
defn.isTupleClass(tree.symbol.owner.companionClass)
686-
)
687-
}
682+
def isCaseClassApply(sym: Symbol): Boolean =
683+
sym.name == nme.apply && sym.isStatic && sym.owner.is(Module) && sym.owner.companionClass.is(Case)
688684
def isCaseClassNew(sym: Symbol): Boolean = {
689685
sym.isPrimaryConstructor && (
690686
sym.eq(defn.SomeClass.primaryConstructor) ||
@@ -701,13 +697,15 @@ trait Checking {
701697
isCaseClassApply(tree.symbol) ||
702698
isCaseClassNew(tree.symbol) ||
703699
isCaseObject(tree.symbol)
704-
if (!allow) ctx.error(em"$what must be a constant expression", tree.pos)
705-
else tree match {
706-
// TODO: add cases for type apply and multiple applies
707-
case Apply(_, args) =>
708-
for (arg <- args)
709-
checkInlineConformant(arg, isFinal, what)
710-
case _ =>
700+
if (!allow) ctx.error(em"$what must be a known value", tree.pos)
701+
else {
702+
def checkArgs(tree: Tree): Unit = tree match {
703+
case Apply(fn, args) =>
704+
args.foreach(arg => checkInlineConformant(arg, isFinal, what))
705+
checkArgs(fn)
706+
case _ =>
707+
}
708+
checkArgs(tree)
711709
}
712710
}
713711
}

compiler/test/dotc/run-test-pickling.blacklist

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ inline-varargs-1
2626
implicitShortcut
2727
inline-case-objects
2828
inline-option
29+
inline-macro-staged-interpreter
2930
inline-tuples-1
3031
inline-tuples-2
3132
lazy-implicit-lists.scala
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
2+
import scala.quoted._
3+
4+
object E {
5+
6+
inline def eval[T](inline x: E[T]): T = ~impl(x)
7+
8+
def impl[T](x: E[T]): Expr[T] = x.lift
9+
10+
}
11+
12+
trait E[T] {
13+
def lift: Expr[T]
14+
}
15+
16+
case class I(n: Int) extends E[Int] {
17+
def lift: Expr[Int] = n.toExpr
18+
}
19+
20+
case class Plus[T](x: E[T], y: E[T])(implicit op: Plus2[T]) extends E[T] {
21+
def lift: Expr[T] = op(x.lift, y.lift)
22+
}
23+
24+
trait Op2[T] {
25+
def apply(x: Expr[T], y: Expr[T]): Expr[T]
26+
}
27+
28+
trait Plus2[T] extends Op2[T]
29+
object Plus2 {
30+
implicit case object IPlus extends Plus2[Int] {
31+
def apply(x: Expr[Int], y: Expr[Int]): Expr[Int] = '(~x + ~y)
32+
}
33+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
2+
object Test {
3+
4+
def main(args: Array[String]): Unit = {
5+
val i = I(2)
6+
E.eval(
7+
i // error
8+
)
9+
10+
E.eval(Plus(
11+
i, // error
12+
I(4)))
13+
14+
val plus = Plus2.IPlus
15+
E.eval(Plus(I(2), I(4))(
16+
plus // error
17+
))
18+
}
19+
20+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2
2+
6
3+
8
4+
14
5+
3.1
6+
20.4
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
2+
import scala.quoted._
3+
4+
object E {
5+
6+
inline def eval[T](inline x: E[T]): T = ~impl(x)
7+
8+
def impl[T](x: E[T]): Expr[T] = x.lift
9+
10+
}
11+
12+
trait E[T] {
13+
def lift: Expr[T]
14+
}
15+
16+
case class I(n: Int) extends E[Int] {
17+
def lift: Expr[Int] = n.toExpr
18+
}
19+
20+
case class D(n: Double) extends E[Double] {
21+
def lift: Expr[Double] = n.toExpr
22+
}
23+
24+
case class Plus[T](x: E[T], y: E[T])(implicit op: Plus2[T]) extends E[T] {
25+
def lift: Expr[T] = op(x.lift, y.lift)
26+
}
27+
28+
case class Times[T](x: E[T], y: E[T])(implicit op: Times2[T]) extends E[T] {
29+
def lift: Expr[T] = op(x.lift, y.lift)
30+
}
31+
32+
trait Op2[T] {
33+
def apply(x: Expr[T], y: Expr[T]): Expr[T]
34+
}
35+
36+
trait Plus2[T] extends Op2[T]
37+
object Plus2 {
38+
implicit case object IPlus extends Plus2[Int] {
39+
def apply(x: Expr[Int], y: Expr[Int]): Expr[Int] = '(~x + ~y)
40+
}
41+
42+
implicit case object DPlus extends Plus2[Double] {
43+
def apply(x: Expr[Double], y: Expr[Double]): Expr[Double] = '(~x + ~y)
44+
}
45+
}
46+
47+
trait Times2[T] extends Op2[T]
48+
object Times2 {
49+
implicit case object ITimes extends Times2[Int] {
50+
def apply(x: Expr[Int], y: Expr[Int]): Expr[Int] = '(~x * ~y)
51+
}
52+
53+
implicit case object DTimes extends Times2[Double] {
54+
def apply(x: Expr[Double], y: Expr[Double]): Expr[Double] = '(~x * ~y)
55+
}
56+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
object Test {
3+
4+
def main(args: Array[String]): Unit = {
5+
println(E.eval(I(2)))
6+
println(E.eval(Plus(I(2), I(4))))
7+
println(E.eval(Times(I(2), I(4))))
8+
println(E.eval(Times(I(2), Plus(I(3), I(4)))))
9+
10+
println(E.eval(D(3.1)))
11+
println(E.eval(Times(D(2.4), Plus(D(3.9), D(4.6)))))
12+
}
13+
14+
}

0 commit comments

Comments
 (0)