Skip to content

Commit d71a347

Browse files
Fix inline match on blocks with multiple statements (scala#20125)
Only the last expression of the block is considered as the inlined scrutinee. Otherwise we may not reduce as much as we should. We also need to make sure that side effects and bindings in the scrutinee are not duplicated. Inlined are converted into blocks to be able to apply the previous semantics without breaking the tree source files. Fixes scala#18151
2 parents adf089b + 530f775 commit d71a347

File tree

5 files changed

+154
-38
lines changed

5 files changed

+154
-38
lines changed

compiler/src/dotty/tools/dotc/inlines/Inliner.scala

Lines changed: 63 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -860,46 +860,71 @@ class Inliner(val call: tpd.Tree)(using Context):
860860
case _ => sel.tpe
861861
}
862862
val selType = if (sel.isEmpty) wideSelType else selTyped(sel)
863-
reduceInlineMatch(sel, selType, cases.asInstanceOf[List[CaseDef]], this) match {
864-
case Some((caseBindings, rhs0)) =>
865-
// drop type ascriptions/casts hiding pattern-bound types (which are now aliases after reducing the match)
866-
// note that any actually necessary casts will be reinserted by the typing pass below
867-
val rhs1 = rhs0 match {
868-
case Block(stats, t) if t.span.isSynthetic =>
869-
t match {
870-
case Typed(expr, _) =>
871-
Block(stats, expr)
872-
case TypeApply(sel@Select(expr, _), _) if sel.symbol.isTypeCast =>
873-
Block(stats, expr)
874-
case _ =>
875-
rhs0
863+
864+
/** Make an Inlined that has no bindings. */
865+
def flattenInlineBlock(tree: Tree): Tree = {
866+
def inlineBlock(call: Tree, stats: List[Tree], expr: Tree): Block =
867+
def inlinedTree(tree: Tree) = Inlined(call, Nil, tree).withSpan(tree.span)
868+
val stats1 = stats.map:
869+
case stat: ValDef => cpy.ValDef(stat)(rhs = inlinedTree(stat.rhs))
870+
case stat: DefDef => cpy.DefDef(stat)(rhs = inlinedTree(stat.rhs))
871+
case stat => inlinedTree(stat)
872+
cpy.Block(tree)(stats1, flattenInlineBlock(inlinedTree(expr)))
873+
874+
tree match
875+
case tree @ Inlined(call, bindings, expr) if !bindings.isEmpty =>
876+
inlineBlock(call, bindings, expr)
877+
case tree @ Inlined(call, Nil, Block(stats, expr)) =>
878+
inlineBlock(call, stats, expr)
879+
case _ =>
880+
tree
881+
}
882+
883+
def reduceInlineMatchExpr(sel: Tree): Tree = flattenInlineBlock(sel) match
884+
case Block(stats, expr) =>
885+
cpy.Block(sel)(stats, reduceInlineMatchExpr(expr))
886+
case _ =>
887+
reduceInlineMatch(sel, selType, cases.asInstanceOf[List[CaseDef]], this) match {
888+
case Some((caseBindings, rhs0)) =>
889+
// drop type ascriptions/casts hiding pattern-bound types (which are now aliases after reducing the match)
890+
// note that any actually necessary casts will be reinserted by the typing pass below
891+
val rhs1 = rhs0 match {
892+
case Block(stats, t) if t.span.isSynthetic =>
893+
t match {
894+
case Typed(expr, _) =>
895+
Block(stats, expr)
896+
case TypeApply(sel@Select(expr, _), _) if sel.symbol.isTypeCast =>
897+
Block(stats, expr)
898+
case _ =>
899+
rhs0
900+
}
901+
case _ => rhs0
876902
}
877-
case _ => rhs0
878-
}
879-
val rhs2 = rhs1 match {
880-
case Typed(expr, tpt) if rhs1.span.isSynthetic => constToLiteral(expr)
881-
case _ => constToLiteral(rhs1)
903+
val rhs2 = rhs1 match {
904+
case Typed(expr, tpt) if rhs1.span.isSynthetic => constToLiteral(expr)
905+
case _ => constToLiteral(rhs1)
906+
}
907+
val (usedBindings, rhs3) = dropUnusedDefs(caseBindings, rhs2)
908+
val rhs = seq(usedBindings, rhs3)
909+
inlining.println(i"""--- reduce:
910+
|$tree
911+
|--- to:
912+
|$rhs""")
913+
typedExpr(rhs, pt)
914+
case None =>
915+
def guardStr(guard: untpd.Tree) = if (guard.isEmpty) "" else i" if $guard"
916+
def patStr(cdef: untpd.CaseDef) = i"case ${cdef.pat}${guardStr(cdef.guard)}"
917+
val msg =
918+
if (tree.selector.isEmpty)
919+
em"""cannot reduce summonFrom with
920+
| patterns : ${tree.cases.map(patStr).mkString("\n ")}"""
921+
else
922+
em"""cannot reduce inline match with
923+
| scrutinee: $sel : ${selType}
924+
| patterns : ${tree.cases.map(patStr).mkString("\n ")}"""
925+
errorTree(tree, msg)
882926
}
883-
val (usedBindings, rhs3) = dropUnusedDefs(caseBindings, rhs2)
884-
val rhs = seq(usedBindings, rhs3)
885-
inlining.println(i"""--- reduce:
886-
|$tree
887-
|--- to:
888-
|$rhs""")
889-
typedExpr(rhs, pt)
890-
case None =>
891-
def guardStr(guard: untpd.Tree) = if (guard.isEmpty) "" else i" if $guard"
892-
def patStr(cdef: untpd.CaseDef) = i"case ${cdef.pat}${guardStr(cdef.guard)}"
893-
val msg =
894-
if (tree.selector.isEmpty)
895-
em"""cannot reduce summonFrom with
896-
| patterns : ${tree.cases.map(patStr).mkString("\n ")}"""
897-
else
898-
em"""cannot reduce inline match with
899-
| scrutinee: $sel : ${selType}
900-
| patterns : ${tree.cases.map(patStr).mkString("\n ")}"""
901-
errorTree(tree, msg)
902-
}
927+
reduceInlineMatchExpr(sel)
903928
}
904929

905930
override def newLikeThis(nestingLevel: Int): Typer = new InlineTyper(initialErrorCount, nestingLevel)

compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,4 +785,36 @@ class InlineBytecodeTests extends DottyBytecodeTest {
785785
}
786786
}
787787

788+
@Test def inline_match_scrutinee_with_side_effect = {
789+
val source = """class Test:
790+
| inline def inlineTest(): Int =
791+
| inline {
792+
| println("scrutinee")
793+
| (1, 2)
794+
| } match
795+
| case (e1, e2) => e1 + e2
796+
|
797+
| def test: Int = inlineTest()
798+
""".stripMargin
799+
800+
checkBCode(source) { dir =>
801+
val clsIn = dir.lookupName("Test.class", directory = false).input
802+
val clsNode = loadClassNode(clsIn)
803+
804+
val fun = getMethod(clsNode, "test")
805+
val instructions = instructionsFromMethod(fun)
806+
val expected = List(
807+
Field(GETSTATIC, "scala/Predef$", "MODULE$", "Lscala/Predef$;"),
808+
Ldc(LDC, "scrutinee"),
809+
Invoke(INVOKEVIRTUAL, "scala/Predef$", "println", "(Ljava/lang/Object;)V", false),
810+
Op(ICONST_3),
811+
Op(IRETURN),
812+
)
813+
814+
assert(instructions == expected,
815+
"`i was not properly inlined in `test`\n" + diffInstructions(instructions, expected))
816+
817+
}
818+
}
819+
788820
}

tests/pos/i18151a.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
case class El[A](attr: String, child: String)
2+
3+
transparent inline def inlineTest(): String =
4+
inline {
5+
val el: El[Any] = El("1", "2")
6+
El[Any](el.attr, el.child)
7+
} match
8+
case El(attr, child) => attr + child
9+
10+
def test: Unit = inlineTest()

tests/pos/i18151b.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
case class El[A](val attr: String, val child: String)
2+
3+
transparent inline def tmplStr(inline t: El[Any]): String =
4+
inline t match
5+
case El(attr, child) => attr + child
6+
7+
def test: Unit = tmplStr {
8+
val el = El("1", "2")
9+
El[Any](el.attr, null)
10+
}

tests/pos/i18151c.scala

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import scala.compiletime.*
2+
import scala.compiletime.ops.any.ToString
3+
4+
trait Attr
5+
case object EmptyAttr extends Attr
6+
transparent inline def attrStr(inline a: Attr): String = inline a match
7+
case EmptyAttr => ""
8+
transparent inline def attrStrHelper(inline a: Attr): String = inline a match
9+
case EmptyAttr => ""
10+
trait TmplNode
11+
case class El[T <: String & Singleton, A <: Attr, C <: Tmpl](val tag: T, val attr: A, val child: C)
12+
extends TmplNode
13+
case class Sib[L <: Tmpl, R <: Tmpl](left: L, right: R) extends TmplNode
14+
type TmplSingleton = String | Char | Int | Long | Float | Double | Boolean
15+
type Tmpl = TmplNode | Unit | (TmplSingleton & Singleton)
16+
transparent inline def tmplStr(inline t: Tmpl): String = inline t match
17+
case El(tag, attr, child) => inline attrStr(attr) match
18+
case "" => "<" + tag + ">" + tmplStr(child)
19+
case x => "<" + tag + " " + x + ">" + tmplStr(child)
20+
case Sib(left, right) => inline tmplStr(right) match
21+
case "" => tmplStr(left)
22+
case right => tmplStrHelper(left) + right
23+
case () => ""
24+
case s: (t & TmplSingleton) => constValue[ToString[t]]
25+
transparent inline def tmplStrHelper(inline t: Tmpl): String = inline t match
26+
case El(tag, attr, child) => inline (tmplStr(child), attrStr(attr)) match
27+
case ("", "") => "<" + tag + "/>"
28+
case (child, "") => "<" + tag + ">" + child + "</" + tag + ">"
29+
case ("", attr) => "<" + tag + " " + attr + "/>"
30+
case (child, attr) => "<" + tag + " " + attr + ">" + child + "</" + tag + ">"
31+
case Sib(left, right) => tmplStrHelper(left) + tmplStrHelper(right)
32+
case () => ""
33+
case s: (t & TmplSingleton) => constValue[ToString[t]]
34+
transparent inline def el(tag: String & Singleton): El[tag.type, EmptyAttr.type, Unit] =
35+
El(tag, EmptyAttr, ())
36+
extension [T <: String & Singleton, A <: Attr, C <: Tmpl](el: El[T, A, C])
37+
transparent inline def >>[C2 <: Tmpl](child: C2) = El(el.tag, el.attr, el.child ++ child)
38+
39+
extension [L <: Tmpl](left: L) transparent inline def ++[R <: Tmpl](right: R) = Sib(left, right)

0 commit comments

Comments
 (0)