Skip to content

Commit ca46f64

Browse files
committed
Improve constant folding logic
* Better handling of inlined expression * Constant fold `String.!=` * Handle `null` in `==` and `!=` Fixes scala#12072
1 parent 85a03ee commit ca46f64

File tree

9 files changed

+170
-8
lines changed

9 files changed

+170
-8
lines changed

compiler/src/dotty/tools/dotc/typer/ConstFold.scala

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ object ConstFold:
2929
def Apply[T <: Apply](tree: T)(using Context): T =
3030
tree.fun match
3131
case Select(xt, op) if foldedBinops.contains(op) =>
32-
xt.tpe.widenTermRefExpr.normalized match
33-
case ConstantType(x) =>
32+
treeConstant(xt) match
33+
case Some(x) =>
3434
tree.args match
3535
case yt :: Nil =>
36-
yt.tpe.widenTermRefExpr.normalized match
37-
case ConstantType(y) => tree.withFoldedType(foldBinop(op, x, y))
36+
treeConstant(yt) match
37+
case Some(y) => tree.withFoldedType(foldBinop(op, x, y))
3838
case _ => tree
3939
case _ => tree
4040
case _ => tree
@@ -46,8 +46,8 @@ object ConstFold:
4646

4747
def Select[T <: Select](tree: T)(using Context): T =
4848
if foldedUnops.contains(tree.name) then
49-
tree.qualifier.tpe.widenTermRefExpr.normalized match
50-
case ConstantType(x) => tree.withFoldedType(foldUnop(tree.name, x))
49+
treeConstant(tree.qualifier) match
50+
case Some(x) => tree.withFoldedType(foldUnop(tree.name, x))
5151
case _ => tree
5252
else tree
5353

@@ -59,6 +59,16 @@ object ConstFold:
5959
tree.withFoldedType(Constant(targ.tpe))
6060
case _ => tree
6161

62+
private def treeConstant(tree: Tree)(using Context): Option[Constant] =
63+
tree match
64+
case Inlined(_, Nil, expr) => treeConstant(expr)
65+
case Typed(expr, _) => treeConstant(expr)
66+
case Literal(c) if c.tag == Constants.NullTag => Some(c)
67+
case _ =>
68+
tree.tpe.widenTermRefExpr.normalized.simplified match
69+
case ConstantType(c) => Some(c)
70+
case _ => None
71+
6272
extension [T <: Tree](tree: T)(using Context)
6373
private def withFoldedType(c: Constant | Null): T =
6474
if c == null then tree else tree.withType(ConstantType(c)).asInstanceOf[T]
@@ -164,15 +174,24 @@ object ConstFold:
164174
case _ => null
165175
}
166176
private def foldStringOp(op: Name, x: Constant, y: Constant): Constant = op match {
167-
case nme.ADD => Constant(x.stringValue + y.stringValue)
177+
case nme.ADD => Constant(x.stringValue + y.stringValue)
168178
case nme.EQ => Constant(x.stringValue == y.stringValue)
179+
case nme.NE => Constant(x.stringValue != y.stringValue)
169180
case _ => null
170181
}
171182

183+
private def foldNullOp(op: Name, x: Constant, y: Constant): Constant =
184+
assert(x.tag == NullTag || y.tag == NullTag)
185+
op match
186+
case nme.EQ => Constant(x.tag == y.tag)
187+
case nme.NE => Constant(x.tag != y.tag)
188+
case _ => null
189+
172190
private def foldBinop(op: Name, x: Constant, y: Constant): Constant =
173191
val optag =
174192
if (x.tag == y.tag) x.tag
175193
else if (x.isNumeric && y.isNumeric) math.max(x.tag, y.tag)
194+
else if (x.tag == NullTag || y.tag == NullTag) NullTag
176195
else NoTag
177196

178197
try optag match
@@ -182,6 +201,7 @@ object ConstFold:
182201
case FloatTag => foldFloatOp(op, x, y)
183202
case DoubleTag => foldDoubleOp(op, x, y)
184203
case StringTag => foldStringOp(op, x, y)
204+
case NullTag => foldNullOp(op, x, y)
185205
case _ => null
186206
catch case ex: ArithmeticException => null // the code will crash at runtime,
187207
// but that is better than the

tests/explicit-nulls/run/nn.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ object Test {
1515
val y: String|Null = null
1616
assertThrowsNPE(y.nn)
1717
assertThrowsNPE(null.nn)
18-
assertThrowsNPE(len(null))
18+
assertThrowsNPE(len(null))
1919
assertThrowsNPE(load(null))
2020
}
2121
}

tests/pos-macros/i12072/Macro_1.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import scala.quoted.*
2+
3+
object M {
4+
5+
transparent inline def f(inline s: String): String | Null =
6+
${ f('s) }
7+
8+
def f(s: Expr[String])(using Quotes): Expr[String | Null] = {
9+
s.valueOrError // required
10+
'{ null }
11+
}
12+
}

tests/pos-macros/i12072/Test_2.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
object T2 {
2+
import M.f
3+
4+
private inline val V = "V"
5+
private inline def D = "D"
6+
7+
trait Trait { def s: String }
8+
9+
object MatchFV extends Trait {
10+
override transparent inline def s: String =
11+
inline f(V) match { case "V" => "o"; case _ => "x" } // error in RC1
12+
}
13+
14+
object MatchFD extends Trait {
15+
override transparent inline def s: String =
16+
inline f(D) match { case "D" => "o"; case _ => "x" }
17+
}
18+
}

tests/pos/i12072-b.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
transparent inline def f: Null = null
2+
3+
inline def g: Unit =
4+
inline if f == "V" then 1 else 2
5+
inline if f != "V" then 3 else 4
6+
inline if "v" == f then 5 else 6
7+
inline if "v" != f then 7 else 8
8+
9+
def test = g

tests/pos/i12072-c.scala

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
object T {
2+
3+
transparent inline def f(inline s: String): String | Null =
4+
null
5+
6+
inline val V = "V"
7+
inline def D = "D"
8+
9+
trait Trait { def s: String }
10+
11+
// ===========================================================================
12+
// inline {if,match} over inline {val,def}
13+
14+
transparent inline def if_v: String =
15+
inline if V == "V" then "o" else "x"
16+
17+
transparent inline def if_d: String =
18+
inline if D == "D" then "o" else "x"
19+
20+
transparent inline def match_v: String =
21+
inline V match { case "V" => "o"; case _ => "x" }
22+
23+
transparent inline def match_d: String =
24+
inline D match { case "D" => "o"; case _ => "x" }
25+
26+
// ===========================================================================
27+
// inline {if,match} over inline f(inline {val,def})
28+
29+
transparent inline def if_fv: String =
30+
inline if f(V) == "V" then "o" else "x"
31+
32+
transparent inline def if_fd: String =
33+
inline if f(D) == "D" then "o" else "x"
34+
35+
transparent inline def match_fv: String =
36+
inline f(V) match { case "V" => "o"; case _ => "x" }
37+
38+
transparent inline def match_fd: String =
39+
inline f(D) match { case "D" => "o"; case _ => "x" }
40+
41+
// ===========================================================================
42+
// inline {if,match} over inline {val,def} in overridden method
43+
44+
object IfV extends Trait {
45+
override transparent inline def s: String =
46+
inline if V == "V" then "o" else "x"
47+
}
48+
49+
object IfD extends Trait {
50+
override transparent inline def s: String =
51+
inline if D == "D" then "o" else "x" // <--------------------------- error
52+
}
53+
54+
object MatchV extends Trait {
55+
override transparent inline def s: String =
56+
inline V match { case "V" => "o"; case _ => "x" }
57+
}
58+
59+
object MatchD extends Trait {
60+
override transparent inline def s: String =
61+
inline D match { case "D" => "o"; case _ => "x" }
62+
}
63+
64+
// ===========================================================================
65+
// inline {if,match} over inline f(inline {val,def}) in overridden method
66+
67+
object IfFV extends Trait {
68+
override transparent inline def s: String =
69+
inline if f(V) == "V" then "o" else "x" // <------------------------ error
70+
}
71+
72+
object IfFD extends Trait {
73+
override transparent inline def s: String =
74+
inline if f(D) == "D" then "o" else "x" // <------------------------ error
75+
}
76+
77+
object MatchFV extends Trait {
78+
override transparent inline def s: String =
79+
inline f(V) match { case "V" => "o"; case _ => "x" }
80+
}
81+
82+
object MatchFD extends Trait {
83+
override transparent inline def s: String =
84+
inline f(D) match { case "D" => "o"; case _ => "x" }
85+
}
86+
}

tests/pos/i12072-d.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
class Test:
2+
def n: Null = null
3+
def test1: Boolean = n == null
4+
def test2: Boolean = null == n

tests/pos/i12072-e.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
def test: Boolean = nn(42) == 42
2+
3+
def nn(x: Int): x.type & Int = ???

tests/pos/i12072.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2+
3+
inline def c: Int = 2
4+
5+
trait A:
6+
def f: Unit
7+
8+
class B extends A:
9+
override inline def f: Unit =
10+
inline if c == 2 then () else ()

0 commit comments

Comments
 (0)