Skip to content

Commit 1328ca3

Browse files
committed
Check type when matching primitive expressions
We might have a reference to an inline val that is not yet constant folded but is sematically equivalent to the value of its type. Fixes #11854
1 parent 85dc1cb commit 1328ca3

File tree

6 files changed

+39
-14
lines changed

6 files changed

+39
-14
lines changed

library/src/scala/quoted/FromExpr.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,14 @@ object FromExpr {
8484
def unapply(expr: Expr[T])(using Quotes) =
8585
import quotes.reflect._
8686
def rec(tree: Term): Option[T] = tree match {
87-
case Literal(c) if c.value != null => Some(c.value.asInstanceOf[T])
8887
case Block(Nil, e) => rec(e)
8988
case Typed(e, _) => rec(e)
9089
case Inlined(_, Nil, e) => rec(e)
91-
case _ => None
90+
case _ =>
91+
tree.tpe.widenTermRefByName match
92+
case ConstantType(c) =>
93+
Some(c.value.asInstanceOf[T])
94+
case _ => None
9295
}
9396
rec(expr.asTerm)
9497
}

tests/neg/i11854.scala

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
class Bag(seq: Seq[Char])
2-
3-
inline val i = 2
4-
inline val j: Int = 2 // error
5-
inline val b: Boolean = true // error
6-
inline val s: String = "" // error
7-
inline val bagA = new Bag(Seq('a', 'b', 'c')) // error
8-
inline val bagB: Bag = new Bag(Seq('a', 'b', 'c')) // error
1+
inline val str1 = "Hello, "
2+
inline val str2 = "Scala 3"
3+
println(Str.concat(str1, str2))

tests/run-macros/expr-map-1.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ kcolb
66
neht
77
esle
88
lav
9-
vals
9+
slav
1010
fed
11-
defs
11+
sfed
1212
fed
1313
rab
1414
yrt

tests/run-macros/expr-map-1/Test_2.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ object Test {
3131

3232
rewrite {
3333
val s: "vals" = "vals"
34-
println(s) // prints "foo" not "oof"
34+
println(s)
3535
}
3636

3737
rewrite {
@@ -41,7 +41,7 @@ object Test {
4141

4242
rewrite {
4343
def s: "defs" = "defs"
44-
println(s) // prints "foo" not "oof"
44+
println(s)
4545
}
4646

4747
rewrite {

tests/run-macros/i11856/Main_2.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
@main def Test: Unit =
2+
inline def str1 = "Hello, "
3+
inline val str2 = "Scala 3"
4+
println(Str.concat(str1, str2))
5+
6+
inline def i1 = 1
7+
inline val i2 = 2
8+
println(I.sum(i1, i2))

tests/run-macros/i11856/Test_1.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import scala.quoted.*
2+
3+
object Str:
4+
inline def concat(inline a: String, inline b: String): String =
5+
${ evalConcat('a, 'b) }
6+
7+
def evalConcat(expra: Expr[String], exprb: Expr[String])(using Quotes): Expr[String] =
8+
val a = expra.valueOrError
9+
val b = exprb.valueOrError
10+
Expr(a ++ b)
11+
12+
object I:
13+
inline def sum(inline a: Int, inline b: Int): Int =
14+
${ evalConcat('a, 'b) }
15+
16+
def evalConcat(expra: Expr[Int], exprb: Expr[Int])(using Quotes): Expr[Int] =
17+
val a = expra.valueOrError
18+
val b = exprb.valueOrError
19+
Expr(a + b)

0 commit comments

Comments
 (0)