diff --git a/compiler/src/dotty/tools/dotc/typer/ConstFold.scala b/compiler/src/dotty/tools/dotc/typer/ConstFold.scala index 4633c187912f..3e09642d291d 100644 --- a/compiler/src/dotty/tools/dotc/typer/ConstFold.scala +++ b/compiler/src/dotty/tools/dotc/typer/ConstFold.scala @@ -29,12 +29,12 @@ object ConstFold: def Apply[T <: Apply](tree: T)(using Context): T = tree.fun match case Select(xt, op) if foldedBinops.contains(op) => - xt.tpe.widenTermRefExpr.normalized match - case ConstantType(x) => + xt match + case ConstantTree(x) => tree.args match case yt :: Nil => - yt.tpe.widenTermRefExpr.normalized match - case ConstantType(y) => tree.withFoldedType(foldBinop(op, x, y)) + yt match + case ConstantTree(y) => tree.withFoldedType(foldBinop(op, x, y)) case _ => tree case _ => tree case _ => tree @@ -46,8 +46,8 @@ object ConstFold: def Select[T <: Select](tree: T)(using Context): T = if foldedUnops.contains(tree.name) then - tree.qualifier.tpe.widenTermRefExpr.normalized match - case ConstantType(x) => tree.withFoldedType(foldUnop(tree.name, x)) + tree.qualifier match + case ConstantTree(x) => tree.withFoldedType(foldUnop(tree.name, x)) case _ => tree else tree @@ -59,6 +59,17 @@ object ConstFold: tree.withFoldedType(Constant(targ.tpe)) case _ => tree + private object ConstantTree: + def unapply(tree: Tree)(using Context): Option[Constant] = + tree match + case Inlined(_, Nil, expr) => unapply(expr) + case Typed(expr, _) => unapply(expr) + case Literal(c) if c.tag == Constants.NullTag => Some(c) + case _ => + tree.tpe.widenTermRefExpr.normalized.simplified match + case ConstantType(c) => Some(c) + case _ => None + extension [T <: Tree](tree: T)(using Context) private def withFoldedType(c: Constant | Null): T = if c == null then tree else tree.withType(ConstantType(c)).asInstanceOf[T] @@ -164,15 +175,24 @@ object ConstFold: case _ => null } private def foldStringOp(op: Name, x: Constant, y: Constant): Constant = op match { - case nme.ADD => Constant(x.stringValue + y.stringValue) + case nme.ADD => Constant(x.stringValue + y.stringValue) case nme.EQ => Constant(x.stringValue == y.stringValue) + case nme.NE => Constant(x.stringValue != y.stringValue) case _ => null } + private def foldNullOp(op: Name, x: Constant, y: Constant): Constant = + assert(x.tag == NullTag || y.tag == NullTag) + op match + case nme.EQ => Constant(x.tag == y.tag) + case nme.NE => Constant(x.tag != y.tag) + case _ => null + private def foldBinop(op: Name, x: Constant, y: Constant): Constant = val optag = if (x.tag == y.tag) x.tag else if (x.isNumeric && y.isNumeric) math.max(x.tag, y.tag) + else if (x.tag == NullTag || y.tag == NullTag) NullTag else NoTag try optag match @@ -182,6 +202,7 @@ object ConstFold: case FloatTag => foldFloatOp(op, x, y) case DoubleTag => foldDoubleOp(op, x, y) case StringTag => foldStringOp(op, x, y) + case NullTag => foldNullOp(op, x, y) case _ => null catch case ex: ArithmeticException => null // the code will crash at runtime, // but that is better than the diff --git a/tests/explicit-nulls/run/nn.scala b/tests/explicit-nulls/run/nn.scala index 3ffff69649cf..12c6c2ddb3c8 100644 --- a/tests/explicit-nulls/run/nn.scala +++ b/tests/explicit-nulls/run/nn.scala @@ -15,7 +15,7 @@ object Test { val y: String|Null = null assertThrowsNPE(y.nn) assertThrowsNPE(null.nn) - assertThrowsNPE(len(null)) + assertThrowsNPE(len(null)) assertThrowsNPE(load(null)) } } diff --git a/tests/pos-macros/i12072/Macro_1.scala b/tests/pos-macros/i12072/Macro_1.scala new file mode 100644 index 000000000000..d58160d8f00a --- /dev/null +++ b/tests/pos-macros/i12072/Macro_1.scala @@ -0,0 +1,12 @@ +import scala.quoted.* + +object M { + + transparent inline def f(inline s: String): String | Null = + ${ f('s) } + + def f(s: Expr[String])(using Quotes): Expr[String | Null] = { + s.valueOrError // required + '{ null } + } +} diff --git a/tests/pos-macros/i12072/Test_2.scala b/tests/pos-macros/i12072/Test_2.scala new file mode 100644 index 000000000000..772e21903fc5 --- /dev/null +++ b/tests/pos-macros/i12072/Test_2.scala @@ -0,0 +1,18 @@ +object T2 { + import M.f + + private inline val V = "V" + private inline def D = "D" + + trait Trait { def s: String } + + object MatchFV extends Trait { + override transparent inline def s: String = + inline f(V) match { case "V" => "o"; case _ => "x" } // error in RC1 + } + + object MatchFD extends Trait { + override transparent inline def s: String = + inline f(D) match { case "D" => "o"; case _ => "x" } + } +} diff --git a/tests/pos/i12072-b.scala b/tests/pos/i12072-b.scala new file mode 100644 index 000000000000..07bca25b68be --- /dev/null +++ b/tests/pos/i12072-b.scala @@ -0,0 +1,9 @@ +transparent inline def f: Null = null + +inline def g: Unit = + inline if f == "V" then 1 else 2 + inline if f != "V" then 3 else 4 + inline if "v" == f then 5 else 6 + inline if "v" != f then 7 else 8 + +def test = g diff --git a/tests/pos/i12072-c.scala b/tests/pos/i12072-c.scala new file mode 100644 index 000000000000..f99f0da9049f --- /dev/null +++ b/tests/pos/i12072-c.scala @@ -0,0 +1,86 @@ +object T { + + transparent inline def f(inline s: String): String | Null = + null + + inline val V = "V" + inline def D = "D" + + trait Trait { def s: String } + + // =========================================================================== + // inline {if,match} over inline {val,def} + + transparent inline def if_v: String = + inline if V == "V" then "o" else "x" + + transparent inline def if_d: String = + inline if D == "D" then "o" else "x" + + transparent inline def match_v: String = + inline V match { case "V" => "o"; case _ => "x" } + + transparent inline def match_d: String = + inline D match { case "D" => "o"; case _ => "x" } + + // =========================================================================== + // inline {if,match} over inline f(inline {val,def}) + + transparent inline def if_fv: String = + inline if f(V) == "V" then "o" else "x" + + transparent inline def if_fd: String = + inline if f(D) == "D" then "o" else "x" + + transparent inline def match_fv: String = + inline f(V) match { case "V" => "o"; case _ => "x" } + + transparent inline def match_fd: String = + inline f(D) match { case "D" => "o"; case _ => "x" } + + // =========================================================================== + // inline {if,match} over inline {val,def} in overridden method + + object IfV extends Trait { + override transparent inline def s: String = + inline if V == "V" then "o" else "x" + } + + object IfD extends Trait { + override transparent inline def s: String = + inline if D == "D" then "o" else "x" // <--------------------------- error + } + + object MatchV extends Trait { + override transparent inline def s: String = + inline V match { case "V" => "o"; case _ => "x" } + } + + object MatchD extends Trait { + override transparent inline def s: String = + inline D match { case "D" => "o"; case _ => "x" } + } + + // =========================================================================== + // inline {if,match} over inline f(inline {val,def}) in overridden method + + object IfFV extends Trait { + override transparent inline def s: String = + inline if f(V) == "V" then "o" else "x" // <------------------------ error + } + + object IfFD extends Trait { + override transparent inline def s: String = + inline if f(D) == "D" then "o" else "x" // <------------------------ error + } + + object MatchFV extends Trait { + override transparent inline def s: String = + inline f(V) match { case "V" => "o"; case _ => "x" } + } + + object MatchFD extends Trait { + override transparent inline def s: String = + inline f(D) match { case "D" => "o"; case _ => "x" } + } +} diff --git a/tests/pos/i12072-d.scala b/tests/pos/i12072-d.scala new file mode 100644 index 000000000000..486e9f478771 --- /dev/null +++ b/tests/pos/i12072-d.scala @@ -0,0 +1,4 @@ +class Test: + def n: Null = null + def test1: Boolean = n == null + def test2: Boolean = null == n diff --git a/tests/pos/i12072-e.scala b/tests/pos/i12072-e.scala new file mode 100644 index 000000000000..713c86fba909 --- /dev/null +++ b/tests/pos/i12072-e.scala @@ -0,0 +1,3 @@ +def test: Boolean = nn(42) == 42 + +def nn(x: Int): x.type & Int = ??? diff --git a/tests/pos/i12072.scala b/tests/pos/i12072.scala new file mode 100644 index 000000000000..ab85059d676a --- /dev/null +++ b/tests/pos/i12072.scala @@ -0,0 +1,8 @@ +inline def c: Int = 2 + +trait A: + def f: Unit + +class B extends A: + override inline def f: Unit = + inline if c == 2 then () else ()