Skip to content

Commit dd6fc82

Browse files
authored
Merge pull request #12080 from dotty-staging/fix-#12072
Improve constant folding logic
2 parents f3c1468 + 54708cc commit dd6fc82

File tree

9 files changed

+169
-8
lines changed

9 files changed

+169
-8
lines changed

compiler/src/dotty/tools/dotc/typer/ConstFold.scala

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ object ConstFold:
2929
def Apply[T <: Apply](tree: T)(using Context): T =
3030
tree.fun match
3131
case Select(xt, op) if foldedBinops.contains(op) =>
32-
xt.tpe.widenTermRefExpr.normalized match
33-
case ConstantType(x) =>
32+
xt match
33+
case ConstantTree(x) =>
3434
tree.args match
3535
case yt :: Nil =>
36-
yt.tpe.widenTermRefExpr.normalized match
37-
case ConstantType(y) => tree.withFoldedType(foldBinop(op, x, y))
36+
yt match
37+
case ConstantTree(y) => tree.withFoldedType(foldBinop(op, x, y))
3838
case _ => tree
3939
case _ => tree
4040
case _ => tree
@@ -46,8 +46,8 @@ object ConstFold:
4646

4747
def Select[T <: Select](tree: T)(using Context): T =
4848
if foldedUnops.contains(tree.name) then
49-
tree.qualifier.tpe.widenTermRefExpr.normalized match
50-
case ConstantType(x) => tree.withFoldedType(foldUnop(tree.name, x))
49+
tree.qualifier match
50+
case ConstantTree(x) => tree.withFoldedType(foldUnop(tree.name, x))
5151
case _ => tree
5252
else tree
5353

@@ -59,6 +59,17 @@ object ConstFold:
5959
tree.withFoldedType(Constant(targ.tpe))
6060
case _ => tree
6161

62+
private object ConstantTree:
63+
def unapply(tree: Tree)(using Context): Option[Constant] =
64+
tree match
65+
case Inlined(_, Nil, expr) => unapply(expr)
66+
case Typed(expr, _) => unapply(expr)
67+
case Literal(c) if c.tag == Constants.NullTag => Some(c)
68+
case _ =>
69+
tree.tpe.widenTermRefExpr.normalized.simplified match
70+
case ConstantType(c) => Some(c)
71+
case _ => None
72+
6273
extension [T <: Tree](tree: T)(using Context)
6374
private def withFoldedType(c: Constant | Null): T =
6475
if c == null then tree else tree.withType(ConstantType(c)).asInstanceOf[T]
@@ -164,15 +175,24 @@ object ConstFold:
164175
case _ => null
165176
}
166177
private def foldStringOp(op: Name, x: Constant, y: Constant): Constant = op match {
167-
case nme.ADD => Constant(x.stringValue + y.stringValue)
178+
case nme.ADD => Constant(x.stringValue + y.stringValue)
168179
case nme.EQ => Constant(x.stringValue == y.stringValue)
180+
case nme.NE => Constant(x.stringValue != y.stringValue)
169181
case _ => null
170182
}
171183

184+
private def foldNullOp(op: Name, x: Constant, y: Constant): Constant =
185+
assert(x.tag == NullTag || y.tag == NullTag)
186+
op match
187+
case nme.EQ => Constant(x.tag == y.tag)
188+
case nme.NE => Constant(x.tag != y.tag)
189+
case _ => null
190+
172191
private def foldBinop(op: Name, x: Constant, y: Constant): Constant =
173192
val optag =
174193
if (x.tag == y.tag) x.tag
175194
else if (x.isNumeric && y.isNumeric) math.max(x.tag, y.tag)
195+
else if (x.tag == NullTag || y.tag == NullTag) NullTag
176196
else NoTag
177197

178198
try optag match
@@ -182,6 +202,7 @@ object ConstFold:
182202
case FloatTag => foldFloatOp(op, x, y)
183203
case DoubleTag => foldDoubleOp(op, x, y)
184204
case StringTag => foldStringOp(op, x, y)
205+
case NullTag => foldNullOp(op, x, y)
185206
case _ => null
186207
catch case ex: ArithmeticException => null // the code will crash at runtime,
187208
// but that is better than the

tests/explicit-nulls/run/nn.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ object Test {
1515
val y: String|Null = null
1616
assertThrowsNPE(y.nn)
1717
assertThrowsNPE(null.nn)
18-
assertThrowsNPE(len(null))
18+
assertThrowsNPE(len(null))
1919
assertThrowsNPE(load(null))
2020
}
2121
}

tests/pos-macros/i12072/Macro_1.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import scala.quoted.*
2+
3+
object M {
4+
5+
transparent inline def f(inline s: String): String | Null =
6+
${ f('s) }
7+
8+
def f(s: Expr[String])(using Quotes): Expr[String | Null] = {
9+
s.valueOrError // required
10+
'{ null }
11+
}
12+
}

tests/pos-macros/i12072/Test_2.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
object T2 {
2+
import M.f
3+
4+
private inline val V = "V"
5+
private inline def D = "D"
6+
7+
trait Trait { def s: String }
8+
9+
object MatchFV extends Trait {
10+
override transparent inline def s: String =
11+
inline f(V) match { case "V" => "o"; case _ => "x" } // error in RC1
12+
}
13+
14+
object MatchFD extends Trait {
15+
override transparent inline def s: String =
16+
inline f(D) match { case "D" => "o"; case _ => "x" }
17+
}
18+
}

tests/pos/i12072-b.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
transparent inline def f: Null = null
2+
3+
inline def g: Unit =
4+
inline if f == "V" then 1 else 2
5+
inline if f != "V" then 3 else 4
6+
inline if "v" == f then 5 else 6
7+
inline if "v" != f then 7 else 8
8+
9+
def test = g

tests/pos/i12072-c.scala

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
object T {
2+
3+
transparent inline def f(inline s: String): String | Null =
4+
null
5+
6+
inline val V = "V"
7+
inline def D = "D"
8+
9+
trait Trait { def s: String }
10+
11+
// ===========================================================================
12+
// inline {if,match} over inline {val,def}
13+
14+
transparent inline def if_v: String =
15+
inline if V == "V" then "o" else "x"
16+
17+
transparent inline def if_d: String =
18+
inline if D == "D" then "o" else "x"
19+
20+
transparent inline def match_v: String =
21+
inline V match { case "V" => "o"; case _ => "x" }
22+
23+
transparent inline def match_d: String =
24+
inline D match { case "D" => "o"; case _ => "x" }
25+
26+
// ===========================================================================
27+
// inline {if,match} over inline f(inline {val,def})
28+
29+
transparent inline def if_fv: String =
30+
inline if f(V) == "V" then "o" else "x"
31+
32+
transparent inline def if_fd: String =
33+
inline if f(D) == "D" then "o" else "x"
34+
35+
transparent inline def match_fv: String =
36+
inline f(V) match { case "V" => "o"; case _ => "x" }
37+
38+
transparent inline def match_fd: String =
39+
inline f(D) match { case "D" => "o"; case _ => "x" }
40+
41+
// ===========================================================================
42+
// inline {if,match} over inline {val,def} in overridden method
43+
44+
object IfV extends Trait {
45+
override transparent inline def s: String =
46+
inline if V == "V" then "o" else "x"
47+
}
48+
49+
object IfD extends Trait {
50+
override transparent inline def s: String =
51+
inline if D == "D" then "o" else "x" // <--------------------------- error
52+
}
53+
54+
object MatchV extends Trait {
55+
override transparent inline def s: String =
56+
inline V match { case "V" => "o"; case _ => "x" }
57+
}
58+
59+
object MatchD extends Trait {
60+
override transparent inline def s: String =
61+
inline D match { case "D" => "o"; case _ => "x" }
62+
}
63+
64+
// ===========================================================================
65+
// inline {if,match} over inline f(inline {val,def}) in overridden method
66+
67+
object IfFV extends Trait {
68+
override transparent inline def s: String =
69+
inline if f(V) == "V" then "o" else "x" // <------------------------ error
70+
}
71+
72+
object IfFD extends Trait {
73+
override transparent inline def s: String =
74+
inline if f(D) == "D" then "o" else "x" // <------------------------ error
75+
}
76+
77+
object MatchFV extends Trait {
78+
override transparent inline def s: String =
79+
inline f(V) match { case "V" => "o"; case _ => "x" }
80+
}
81+
82+
object MatchFD extends Trait {
83+
override transparent inline def s: String =
84+
inline f(D) match { case "D" => "o"; case _ => "x" }
85+
}
86+
}

tests/pos/i12072-d.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
class Test:
2+
def n: Null = null
3+
def test1: Boolean = n == null
4+
def test2: Boolean = null == n

tests/pos/i12072-e.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
def test: Boolean = nn(42) == 42
2+
3+
def nn(x: Int): x.type & Int = ???

tests/pos/i12072.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
inline def c: Int = 2
2+
3+
trait A:
4+
def f: Unit
5+
6+
class B extends A:
7+
override inline def f: Unit =
8+
inline if c == 2 then () else ()

0 commit comments

Comments
 (0)