Skip to content

Commit b86af07

Browse files
committed
Fix #9246 and #6800: Remove pure statements
1 parent 8c5a58f commit b86af07

File tree

3 files changed

+99
-0
lines changed

3 files changed

+99
-0
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class Compiler {
9595
new ArrayConstructors) :: // Intercept creation of (non-generic) arrays and intrinsify.
9696
List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements.
9797
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
98+
new PureStats, // Remove pure stats from blocks
9899
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations
99100
new ArrayApply, // Optimize `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]`
100101
new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import ast.{Trees, tpd}
5+
import core._, core.Decorators._
6+
import MegaPhase._
7+
import Types._, Contexts._, Flags._, DenotTransformers._
8+
import Symbols._, StdNames._, Trees._
9+
10+
object PureStats {
11+
val name: String = "pureStats"
12+
}
13+
14+
/** Remove pure statements in blocks */
15+
class PureStats extends MiniPhase {
16+
17+
import tpd._
18+
19+
override def phaseName: String = PureStats.name
20+
21+
override def runsAfter: Set[String] = Set(Erasure.name)
22+
23+
override def transformBlock(tree: Block)(implicit ctx: Context): Tree =
24+
val stats = Trees.flatten(
25+
tree.stats.mapConserve {
26+
case Typed(Block(stats, expr), _) if isPureExpr(expr) => Thicket(stats)
27+
case stat if !stat.symbol.isConstructor && isPureExpr(stat) => EmptyTree
28+
case stat => stat
29+
})
30+
cpy.Block(tree)(stats, tree.expr)
31+
32+
}

compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,4 +418,70 @@ class InlineBytecodeTests extends DottyBytecodeTest {
418418

419419
}
420420
}
421+
422+
@Test def i6800a = {
423+
val source = """class Foo:
424+
| inline def inlined(f: => Unit): Unit = f
425+
| def test: Unit = inlined { println("") }
426+
""".stripMargin
427+
428+
checkBCode(source) { dir =>
429+
val clsIn = dir.lookupName("Foo.class", directory = false).input
430+
val clsNode = loadClassNode(clsIn)
431+
432+
val fun = getMethod(clsNode, "test")
433+
val instructions = instructionsFromMethod(fun)
434+
val expected = List(Invoke(INVOKESTATIC, "Foo", "f$1", "()V", false), Op(RETURN))
435+
assert(instructions == expected,
436+
"`inlined` was not properly inlined in `test`\n" + diffInstructions(instructions, expected))
437+
438+
}
439+
}
440+
441+
@Test def i6800b = {
442+
val source = """class Foo:
443+
| inline def printIfZero(x: Int): Unit = inline x match
444+
| case 0 => println("zero")
445+
| case _ => ()
446+
| def test: Unit = printIfZero(0)
447+
""".stripMargin
448+
449+
checkBCode(source) { dir =>
450+
val clsIn = dir.lookupName("Foo.class", directory = false).input
451+
val clsNode = loadClassNode(clsIn)
452+
453+
val fun = getMethod(clsNode, "test")
454+
val instructions = instructionsFromMethod(fun)
455+
val expected = List(
456+
Field(GETSTATIC, "scala/Predef$", "MODULE$", "Lscala/Predef$;"),
457+
Ldc(LDC, "zero"),
458+
Invoke(INVOKEVIRTUAL, "scala/Predef$", "println", "(Ljava/lang/Object;)V", false),
459+
Op(RETURN)
460+
)
461+
assert(instructions == expected,
462+
"`printIfZero` was not properly inlined in `test`\n" + diffInstructions(instructions, expected))
463+
}
464+
}
465+
466+
467+
@Test def i9246 = {
468+
val source = """class Foo:
469+
| inline def check(v:Double): Unit = if(v==0) throw new Exception()
470+
| inline def divide(v: Double, d: Double): Double = { check(d); v / d }
471+
| def test = divide(10,2)
472+
""".stripMargin
473+
474+
checkBCode(source) { dir =>
475+
val clsIn = dir.lookupName("Foo.class", directory = false).input
476+
val clsNode = loadClassNode(clsIn)
477+
478+
val fun = getMethod(clsNode, "test")
479+
val instructions = instructionsFromMethod(fun)
480+
val expected = List(Ldc(LDC, 5.0), Op(DRETURN))
481+
assert(instructions == expected,
482+
"`divide` was not properly inlined in `test`\n" + diffInstructions(instructions, expected))
483+
484+
}
485+
}
486+
421487
}

0 commit comments

Comments
 (0)