Skip to content

Improve constant folding logic #12080

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions compiler/src/dotty/tools/dotc/typer/ConstFold.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ object ConstFold:
def Apply[T <: Apply](tree: T)(using Context): T =
tree.fun match
case Select(xt, op) if foldedBinops.contains(op) =>
xt.tpe.widenTermRefExpr.normalized match
case ConstantType(x) =>
xt match
case ConstantTree(x) =>
tree.args match
case yt :: Nil =>
yt.tpe.widenTermRefExpr.normalized match
case ConstantType(y) => tree.withFoldedType(foldBinop(op, x, y))
yt match
case ConstantTree(y) => tree.withFoldedType(foldBinop(op, x, y))
case _ => tree
case _ => tree
case _ => tree
Expand All @@ -46,8 +46,8 @@ object ConstFold:

def Select[T <: Select](tree: T)(using Context): T =
if foldedUnops.contains(tree.name) then
tree.qualifier.tpe.widenTermRefExpr.normalized match
case ConstantType(x) => tree.withFoldedType(foldUnop(tree.name, x))
tree.qualifier match
case ConstantTree(x) => tree.withFoldedType(foldUnop(tree.name, x))
case _ => tree
else tree

Expand All @@ -59,6 +59,17 @@ object ConstFold:
tree.withFoldedType(Constant(targ.tpe))
case _ => tree

private object ConstantTree:
def unapply(tree: Tree)(using Context): Option[Constant] =
tree match
case Inlined(_, Nil, expr) => unapply(expr)
case Typed(expr, _) => unapply(expr)
case Literal(c) if c.tag == Constants.NullTag => Some(c)
case _ =>
tree.tpe.widenTermRefExpr.normalized.simplified match
case ConstantType(c) => Some(c)
case _ => None

extension [T <: Tree](tree: T)(using Context)
private def withFoldedType(c: Constant | Null): T =
if c == null then tree else tree.withType(ConstantType(c)).asInstanceOf[T]
Expand Down Expand Up @@ -164,15 +175,24 @@ object ConstFold:
case _ => null
}
private def foldStringOp(op: Name, x: Constant, y: Constant): Constant = op match {
case nme.ADD => Constant(x.stringValue + y.stringValue)
case nme.ADD => Constant(x.stringValue + y.stringValue)
case nme.EQ => Constant(x.stringValue == y.stringValue)
case nme.NE => Constant(x.stringValue != y.stringValue)
case _ => null
}

private def foldNullOp(op: Name, x: Constant, y: Constant): Constant =
assert(x.tag == NullTag || y.tag == NullTag)
op match
case nme.EQ => Constant(x.tag == y.tag)
case nme.NE => Constant(x.tag != y.tag)
case _ => null

private def foldBinop(op: Name, x: Constant, y: Constant): Constant =
val optag =
if (x.tag == y.tag) x.tag
else if (x.isNumeric && y.isNumeric) math.max(x.tag, y.tag)
else if (x.tag == NullTag || y.tag == NullTag) NullTag
else NoTag

try optag match
Expand All @@ -182,6 +202,7 @@ object ConstFold:
case FloatTag => foldFloatOp(op, x, y)
case DoubleTag => foldDoubleOp(op, x, y)
case StringTag => foldStringOp(op, x, y)
case NullTag => foldNullOp(op, x, y)
case _ => null
catch case ex: ArithmeticException => null // the code will crash at runtime,
// but that is better than the
Expand Down
2 changes: 1 addition & 1 deletion tests/explicit-nulls/run/nn.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ object Test {
val y: String|Null = null
assertThrowsNPE(y.nn)
assertThrowsNPE(null.nn)
assertThrowsNPE(len(null))
assertThrowsNPE(len(null))
assertThrowsNPE(load(null))
}
}
12 changes: 12 additions & 0 deletions tests/pos-macros/i12072/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import scala.quoted.*

object M {

transparent inline def f(inline s: String): String | Null =
${ f('s) }

def f(s: Expr[String])(using Quotes): Expr[String | Null] = {
s.valueOrError // required
'{ null }
}
}
18 changes: 18 additions & 0 deletions tests/pos-macros/i12072/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
object T2 {
import M.f

private inline val V = "V"
private inline def D = "D"

trait Trait { def s: String }

object MatchFV extends Trait {
override transparent inline def s: String =
inline f(V) match { case "V" => "o"; case _ => "x" } // error in RC1
}

object MatchFD extends Trait {
override transparent inline def s: String =
inline f(D) match { case "D" => "o"; case _ => "x" }
}
}
9 changes: 9 additions & 0 deletions tests/pos/i12072-b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
transparent inline def f: Null = null

inline def g: Unit =
inline if f == "V" then 1 else 2
inline if f != "V" then 3 else 4
inline if "v" == f then 5 else 6
inline if "v" != f then 7 else 8

def test = g
86 changes: 86 additions & 0 deletions tests/pos/i12072-c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
object T {

transparent inline def f(inline s: String): String | Null =
null

inline val V = "V"
inline def D = "D"

trait Trait { def s: String }

// ===========================================================================
// inline {if,match} over inline {val,def}

transparent inline def if_v: String =
inline if V == "V" then "o" else "x"

transparent inline def if_d: String =
inline if D == "D" then "o" else "x"

transparent inline def match_v: String =
inline V match { case "V" => "o"; case _ => "x" }

transparent inline def match_d: String =
inline D match { case "D" => "o"; case _ => "x" }

// ===========================================================================
// inline {if,match} over inline f(inline {val,def})

transparent inline def if_fv: String =
inline if f(V) == "V" then "o" else "x"

transparent inline def if_fd: String =
inline if f(D) == "D" then "o" else "x"

transparent inline def match_fv: String =
inline f(V) match { case "V" => "o"; case _ => "x" }

transparent inline def match_fd: String =
inline f(D) match { case "D" => "o"; case _ => "x" }

// ===========================================================================
// inline {if,match} over inline {val,def} in overridden method

object IfV extends Trait {
override transparent inline def s: String =
inline if V == "V" then "o" else "x"
}

object IfD extends Trait {
override transparent inline def s: String =
inline if D == "D" then "o" else "x" // <--------------------------- error
}

object MatchV extends Trait {
override transparent inline def s: String =
inline V match { case "V" => "o"; case _ => "x" }
}

object MatchD extends Trait {
override transparent inline def s: String =
inline D match { case "D" => "o"; case _ => "x" }
}

// ===========================================================================
// inline {if,match} over inline f(inline {val,def}) in overridden method

object IfFV extends Trait {
override transparent inline def s: String =
inline if f(V) == "V" then "o" else "x" // <------------------------ error
}

object IfFD extends Trait {
override transparent inline def s: String =
inline if f(D) == "D" then "o" else "x" // <------------------------ error
}

object MatchFV extends Trait {
override transparent inline def s: String =
inline f(V) match { case "V" => "o"; case _ => "x" }
}

object MatchFD extends Trait {
override transparent inline def s: String =
inline f(D) match { case "D" => "o"; case _ => "x" }
}
}
4 changes: 4 additions & 0 deletions tests/pos/i12072-d.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class Test:
def n: Null = null
def test1: Boolean = n == null
def test2: Boolean = null == n
3 changes: 3 additions & 0 deletions tests/pos/i12072-e.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def test: Boolean = nn(42) == 42

def nn(x: Int): x.type & Int = ???
8 changes: 8 additions & 0 deletions tests/pos/i12072.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
inline def c: Int = 2

trait A:
def f: Unit

class B extends A:
override inline def f: Unit =
inline if c == 2 then () else ()