Skip to content

Constant folding improvements #3466

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 20, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 15 additions & 21 deletions compiler/src/dotty/tools/dotc/transform/localopt/ConstantFold.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@ import Simplify.desugarIdent
* out (nested) if with equivalent branches wrt to isSimilar. For example:
* - if (b) exp else exp → b; exp
* - if (b1) e1 else if (b2) e1 else e2 → if (b1 || b2) e1 else e2
* - if(!b) e1 else e2 → if(b) e2 else e1
*
* - Constant propagation over pattern matching.
*
* @author DarkDimius, OlivierBlanvillain
* @author DarkDimius, OlivierBlanvillain, gan74
*/
class ConstantFold(val simplifyPhase: Simplify) extends Optimisation {
import ast.tpd._

def visitor(implicit ctx: Context) = NoVisitor
def clear(): Unit = ()

def transformer(implicit ctx: Context): Tree => Tree = { x => preEval(x) match {
def transformer(implicit ctx: Context): Tree => Tree = {
// TODO: include handling of isInstanceOf similar to one in IsInstanceOfEvaluator
// TODO: include methods such as Int.int2double(see ./tests/pos/harmonize.scala)
case If(cond1, thenp, elsep) if isSimilar(thenp, elsep) =>
Expand Down Expand Up @@ -75,7 +76,7 @@ import Simplify.desugarIdent
// isBool(ift.tpe) && !elsep.const.booleanValue =>
// cond.select(defn.Boolean_&&).appliedTo(elsep)
// the other case ins't handled intentionally. See previous case for explanation

case If(t @ Select(recv, _), thenp, elsep) if t.symbol eq defn.Boolean_! =>
If(recv, elsep, thenp)

Expand Down Expand Up @@ -141,6 +142,15 @@ import Simplify.desugarIdent
// Block(List(lhs),
// ref(defn.throwMethod).appliedTo(New(defn.ArithmeticExceptionClass.typeRef, defn.ArithmeticExceptionClass_stringConstructor, Literal(Constant("/ by zero")) :: Nil)))

case (l: Literal, r: Literal) =>
(l.tpe.widenTermRefExpr, r.tpe.widenTermRefExpr) match {
case (ConstantType(_), ConstantType(_)) =>
val s = ConstFold.apply(t)
if ((s ne null) && s.tpe.isInstanceOf[ConstantType]) Literal(s.tpe.asInstanceOf[ConstantType].value)
else t
case _ => t
}

case _ => t
}

Expand All @@ -157,26 +167,10 @@ import Simplify.desugarIdent

case t: Literal => t
case t: CaseDef => t
case t if !isPureExpr(t) => t
case t =>
val s = ConstFold.apply(t)
if ((s ne null) && s.tpe.isInstanceOf[ConstantType]) {
val constant = s.tpe.asInstanceOf[ConstantType].value
Literal(constant)
} else t
}
case t => t
}

def preEval(t: Tree)(implicit ctx: Context) = {
if (t.isInstanceOf[Literal] || t.isInstanceOf[CaseDef] || !isPureExpr(t)) t
else {
val s = ConstFold.apply(t)
if ((s ne null) && s.tpe.isInstanceOf[ConstantType]) {
val constant = s.tpe.asInstanceOf[ConstantType].value
Literal(constant)
} else t
}
}


def isSimilar(t1: Tree, t2: Tree)(implicit ctx: Context): Boolean = t1 match {
case t1: Apply =>
Expand Down
62 changes: 51 additions & 11 deletions compiler/test/dotty/tools/dotc/SimplifyTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,6 @@ abstract class SimplifyTests(val optimise: Boolean) extends DottyBytecodeTest {
|print(new Some(new Tuple2(1, "s")))
""")

@Test def constantFold =
check(
"""
|val t = true // val needed, or typer takes care of this
|if (t) print(1)
|else print(2)
""",
"""
|print(1)
""")

@Test def dropNoEffects =
check(
"""
Expand Down Expand Up @@ -124,6 +113,57 @@ abstract class SimplifyTests(val optimise: Boolean) extends DottyBytecodeTest {
|println(true)
""")


/*
* Constant folding tests
*/

@Test def basicConstantFold =
check(
"""
|val i = 3
|val j = i + 4
|print(j)
""",
"""
|print(7)
""")

@Test def branchConstantFold =
check(
"""
|val t = true // val needed, or typer takes care of this
|if (t) print(1)
|else print(2)
""",
"""
|print(1)
""")

@Test def arithmeticConstantFold =
check(
"""
|val i = 3
|val j = i + 4
|if(j - i >= (i + 1) / 2)
| print(i + 1)
""",
"""
|print(4)
""")

@Test def twoValConstantFold =
check(
"""
|val i = 3
|val j = 4
|val k = i * j
|print(k - j)
""",
"""
|print(8)
""")

// @Test def listPatmapExample =
// check(
// """
Expand Down