Skip to content

Commit 06f311c

Browse files
authored
Merge pull request #8348 from fhackett/fhackett-fix-8306
Fix #8306: Ensure the inliner can elide effectively pure case class applications in various situations
2 parents abc9815 + ae336b6 commit 06f311c

File tree

3 files changed

+205
-10
lines changed

3 files changed

+205
-10
lines changed

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
424424
computeParamBindings(tp.resultType, Nil, argss)
425425
case tp: MethodType =>
426426
if argss.isEmpty then
427-
ctx.error(i"mising arguments for inline method $inlinedMethod", call.sourcePos)
427+
ctx.error(i"missing arguments for inline method $inlinedMethod", call.sourcePos)
428428
false
429429
else
430430
tp.paramNames.lazyZip(tp.paramInfos).lazyZip(argss.head).foreach { (name, paramtp, arg) =>
@@ -477,6 +477,73 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
477477
tpe.cls.isContainedIn(inlinedMethod) ||
478478
tpe.cls.is(Package)
479479

480+
/** Very similar to TreeInfo.isPureExpr, but with the following inliner-only exceptions:
481+
* - synthetic case class apply methods, when the case class constructor is empty, are
482+
* elideable but not pure. Elsewhere, accessing the apply method might cause the initialization
483+
* of a containing object so they are merely idempotent.
484+
*/
485+
object isElideableExpr {
486+
def isStatElideable(tree: Tree)(implicit ctx: Context): Boolean = unsplice(tree) match {
487+
case EmptyTree
488+
| TypeDef(_, _)
489+
| Import(_, _)
490+
| DefDef(_, _, _, _, _) =>
491+
true
492+
case vdef @ ValDef(_, _, _) =>
493+
if (vdef.symbol.flags is Mutable) false else apply(vdef.rhs)
494+
case _ =>
495+
false
496+
}
497+
498+
def apply(tree: Tree): Boolean = unsplice(tree) match {
499+
case EmptyTree
500+
| This(_)
501+
| Super(_, _)
502+
| Literal(_) =>
503+
true
504+
case Ident(_) =>
505+
isPureRef(tree)
506+
case Select(qual, _) =>
507+
if (tree.symbol.is(Erased)) true
508+
else isPureRef(tree) && apply(qual)
509+
case New(_) | Closure(_, _, _) =>
510+
true
511+
case TypeApply(fn, _) =>
512+
if (fn.symbol.is(Erased) || fn.symbol == defn.InternalQuoted_typeQuote) true else apply(fn)
513+
case Apply(fn, args) =>
514+
def isKnownPureOp(sym: Symbol) =
515+
sym.owner.isPrimitiveValueClass
516+
|| sym.owner == defn.StringClass
517+
|| defn.pureMethods.contains(sym)
518+
val isCaseClassApply = {
519+
val cls = tree.tpe.classSymbol
520+
val meth = fn.symbol
521+
meth.name == nme.apply &&
522+
meth.flags.is(Synthetic) &&
523+
meth.owner.linkedClass.is(Case) &&
524+
cls.isNoInitsClass
525+
}
526+
if (tree.tpe.isInstanceOf[ConstantType] && isKnownPureOp(tree.symbol) // A constant expression with pure arguments is pure.
527+
|| (fn.symbol.isStableMember && !fn.symbol.is(Lazy))
528+
|| fn.symbol.isPrimaryConstructor && fn.symbol.owner.isNoInitsClass) // TODO: include in isStable?
529+
apply(fn) && args.forall(apply)
530+
else if (isCaseClassApply)
531+
args.forall(apply)
532+
else if (fn.symbol.is(Erased)) true
533+
else false
534+
case Typed(expr, _) =>
535+
apply(expr)
536+
case Block(stats, expr) =>
537+
apply(expr) && stats.forall(isStatElideable)
538+
case Inlined(_, bindings, expr) =>
539+
apply(expr) && bindings.forall(isStatElideable)
540+
case NamedArg(_, expr) =>
541+
apply(expr)
542+
case _ =>
543+
false
544+
}
545+
}
546+
480547
/** Populate `thisProxy` and `paramProxy` as follows:
481548
*
482549
* 1a. If given type refers to a static this, thisProxy binds it to corresponding global reference,
@@ -739,13 +806,16 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
739806
Some(meth.owner.linkedClass, args, Nil, false)
740807
else None
741808
}
809+
case Typed(inner, _) =>
810+
// drop the ascribed tpt. We only need it if we can't find a NewInstance
811+
unapply(inner)
742812
case Ident(_) =>
743813
val binding = tree.symbol.defTree
744814
for ((cls, reduced, prefix, precomputed) <- unapply(binding))
745815
yield (cls, reduced, prefix, precomputed || binding.isInstanceOf[ValDef])
746816
case Inlined(_, bindings, expansion) =>
747817
unapplyLet(bindings, expansion)
748-
case Block(stats, expr) if isPureExpr(tree) =>
818+
case Block(stats, expr) if isElideableExpr(tree) =>
749819
unapplyLet(stats, expr)
750820
case _ =>
751821
None
@@ -778,13 +848,13 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
778848
.reporting(i"projecting $tree -> $result", inlining)
779849
val arg = args(idx)
780850
if (precomputed)
781-
if (isPureExpr(arg)) finish(arg)
851+
if (isElideableExpr(arg)) finish(arg)
782852
else tree // nothing we can do here, projection would duplicate side effect
783853
else {
784854
// newInstance is evaluated in place, need to reflect side effects of
785855
// arguments in the order they were written originally
786856
def collectImpure(from: Int, end: Int) =
787-
(from until end).filterNot(i => isPureExpr(args(i))).toList.map(args)
857+
(from until end).filterNot(i => isElideableExpr(args(i))).toList.map(args)
788858
val leading = collectImpure(0, idx)
789859
val trailing = collectImpure(idx + 1, args.length)
790860
val argInPlace =
@@ -1041,7 +1111,9 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
10411111
case (Nil, Nil) => true
10421112
case (pat :: pats1, selector :: selectors1) =>
10431113
val elem = newSym(InlineBinderName.fresh(), Synthetic, selector.tpe.widenTermRefExpr).asTerm
1044-
caseBindingMap += ((NoSymbol, ValDef(elem, constToLiteral(selector)).withSpan(elem.span)))
1114+
val rhs = constToLiteral(selector)
1115+
elem.defTree = rhs
1116+
caseBindingMap += ((NoSymbol, ValDef(elem, rhs).withSpan(elem.span)))
10451117
reducePattern(caseBindingMap, elem.termRef, pat) &&
10461118
reduceSubPatterns(pats1, selectors1)
10471119
case _ => false
@@ -1143,9 +1215,14 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
11431215
override def typedSelect(tree: untpd.Select, pt: Type)(using Context): Tree = {
11441216
assert(tree.hasType, tree)
11451217
val qual1 = typed(tree.qualifier, selectionProto(tree.name, pt, this))
1146-
val res = constToLiteral(untpd.cpy.Select(tree)(qual1, tree.name).withType(tree.typeOpt))
1147-
ensureAccessible(res.tpe, tree.qualifier.isInstanceOf[untpd.Super], tree.sourcePos)
1148-
checkStaging(res)
1218+
val resNoReduce = untpd.cpy.Select(tree)(qual1, tree.name).withType(tree.typeOpt)
1219+
val resMaybeReduced = constToLiteral(reducer.reduceProjection(resNoReduce))
1220+
if (resNoReduce ne resMaybeReduced)
1221+
typed(resMaybeReduced, pt) // redo typecheck if reduction changed something
1222+
else
1223+
val res = resMaybeReduced
1224+
ensureAccessible(res.tpe, tree.qualifier.isInstanceOf[untpd.Super], tree.sourcePos)
1225+
checkStaging(res)
11491226
}
11501227

11511228
private def checkStaging(tree: Tree): tree.type =
@@ -1264,8 +1341,8 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
12641341
val bindingOfSym = newMutableSymbolMap[MemberDef]
12651342

12661343
def isInlineable(binding: MemberDef) = binding match {
1267-
case ddef @ DefDef(_, Nil, Nil, _, _) => isPureExpr(ddef.rhs)
1268-
case vdef @ ValDef(_, _, _) => isPureExpr(vdef.rhs)
1344+
case ddef @ DefDef(_, Nil, Nil, _, _) => isElideableExpr(ddef.rhs)
1345+
case vdef @ ValDef(_, _, _) => isElideableExpr(vdef.rhs)
12691346
case _ => false
12701347
}
12711348
for (binding <- bindings if isInlineable(binding)) {

tests/run/i8306.check

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
compile-time: 3
2+
run-time: 3
3+
compile-time: 3
4+
run-time: 3
5+
compile-time: 3
6+
run-time: 3
7+
compile-time: 3
8+
run-time: 3
9+
compile-time: {
10+
val $elem9: A = Test.a
11+
val $elem10: Int = $elem9.i
12+
val i: Int = $elem10
13+
i:Int
14+
}
15+
run-time: 3
16+
compile-time: 3
17+
run-time: 3
18+
compile-time: 3
19+
run-time: 3
20+
compile-time: 3
21+
run-time: 3
22+
compile-time: 3
23+
run-time: 3
24+
compile-time: 3
25+
run-time: 3

tests/run/i8306.scala

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import scala.compiletime._
2+
3+
case class A(i: Int)
4+
case class B(a: A)
5+
case class C[T](t: T)
6+
7+
trait Test8 {
8+
inline def test8: Int =
9+
inline A(3) match {
10+
case A(i) => i
11+
}
12+
}
13+
14+
object Test extends Test8 {
15+
16+
inline def test1: Int =
17+
inline A(3) match {
18+
case A(i) => i
19+
}
20+
21+
inline def test2: Int =
22+
inline (A(3) : A) match {
23+
case A(i) => i
24+
}
25+
26+
inline def test3: Int =
27+
inline B(A(3)) match {
28+
case B(A(i)) => i
29+
}
30+
31+
inline def test4: Int =
32+
A(3).i
33+
34+
val a = A(3)
35+
inline def test5: Int =
36+
inline new B(a) match {
37+
case B(A(i)) => i
38+
}
39+
40+
inline def test6: Int =
41+
inline B(A(3)).a match {
42+
case A(i) => i
43+
}
44+
45+
inline def test7: Int =
46+
inline new A(3) match {
47+
case A(i) => i
48+
}
49+
50+
inline def test9: Int =
51+
B(A(3)).a.i
52+
53+
inline def test10: Int =
54+
inline C(3) match {
55+
case C(t) => t
56+
}
57+
58+
def main(argv: Array[String]): Unit = {
59+
println(code"compile-time: ${test1}")
60+
println(s"run-time: ${test1}")
61+
62+
println(code"compile-time: ${test2}")
63+
println(s"run-time: ${test2}")
64+
65+
println(code"compile-time: ${test3}")
66+
println(s"run-time: ${test3}")
67+
68+
println(code"compile-time: ${test4}")
69+
println(s"run-time: ${test4}")
70+
71+
// this is the only test that should not be possible to fully inline,
72+
// because it references a non-inline value
73+
println(code"compile-time: ${test5}")
74+
println(s"run-time: ${test5}")
75+
76+
println(code"compile-time: ${test6}")
77+
println(s"run-time: ${test6}")
78+
79+
println(code"compile-time: ${test7}")
80+
println(s"run-time: ${test7}")
81+
82+
println(code"compile-time: ${test8}")
83+
println(s"run-time: ${test8}")
84+
85+
println(code"compile-time: ${test9}")
86+
println(s"run-time: ${test9}")
87+
88+
// with type parameter
89+
println(code"compile-time: ${test10}")
90+
println(s"run-time: ${test10}")
91+
}
92+
}
93+

0 commit comments

Comments
 (0)