Skip to content

Commit a475e47

Browse files
committed
Fix #8033: Eliminate unused outer accessors
1 parent 4cb1cda commit a475e47

File tree

6 files changed

+167
-4
lines changed

6 files changed

+167
-4
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,9 @@ class Compiler {
109109
new GetClass) :: // Rewrites getClass calls on primitive types.
110110
List(new LinkScala2Impls, // Redirect calls to trait methods defined by Scala 2.x, so that they now go to
111111
new LambdaLift, // Lifts out nested functions to class scope, storing free variables in environments
112-
// Note: in this mini-phase block scopes are incorrect. No phases that rely on scopes should be here
113-
new ElimStaticThis) :: // Replace `this` references to static objects by global identifiers
112+
// Note: in this mini-phase block scopes are incorrect. No phases that rely on scopes should be here
113+
new ElimStaticThis, // Replace `this` references to static objects by global identifiers
114+
new DropOuterAccessors) :: // Drop unused outer accessors
114115
List(new Flatten, // Lift all inner classes to package scope
115116
new RenameLifted, // Renames lifted classes to local numbering scheme
116117
new TransformWildcards, // Replace wildcards with default values

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1079,9 +1079,15 @@ object SymDenotations {
10791079
final def lexicallyEnclosingClass(implicit ctx: Context): Symbol =
10801080
if (!exists || isClass) symbol else owner.lexicallyEnclosingClass
10811081

1082+
/** A class is extensible if it is not final, nor a module class,
1083+
* nor an anonymous class.
1084+
*/
1085+
final def isExtensibleClass(using Context): Boolean =
1086+
isClass && !isOneOf(FinalOrModuleClass) && !isAnonymousClass
1087+
10821088
/** A symbol is effectively final if it cannot be overridden in a subclass */
10831089
final def isEffectivelyFinal(implicit ctx: Context): Boolean =
1084-
isOneOf(EffectivelyFinalFlags) || !owner.isClass || owner.isOneOf(FinalOrModuleClass) || owner.isAnonymousClass
1090+
isOneOf(EffectivelyFinalFlags) || !owner.isExtensibleClass
10851091

10861092
/** A class is effectively sealed if has the `final` or `sealed` modifier, or it
10871093
* is defined in Scala 3 and is neither abstract nor open.
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import core._
5+
import MegaPhase.MiniPhase
6+
import dotty.tools.dotc.core.Contexts.Context
7+
import ast._
8+
import Trees._
9+
import Flags._
10+
import Symbols._
11+
import Decorators._
12+
import DenotTransformers._
13+
import StdNames.nme
14+
import collection.mutable
15+
16+
object DropOuterAccessors:
17+
val name: String = "dropOuterAccessors"
18+
19+
/** Drops outer accessors of final classes that are unused */
20+
class DropOuterAccessors extends MiniPhase with IdentityDenotTransformer:
21+
thisPhase =>
22+
import tpd._
23+
24+
override def phaseName: String = DropOuterAccessors.name
25+
26+
override def runsAfter: Set[String] = Set(LambdaLift.name)
27+
// LambdaLift can create outer paths. These need to be known in this phase
28+
29+
override def changesMembers: Boolean = true // the phase drops outer accessors
30+
31+
def (sym: Symbol).isOuterParamAccessor(using Context) =
32+
sym.is(ParamAccessor) && sym.name == nme.OUTER
33+
34+
private def mightBeDropped(sym: Symbol)(using Context) =
35+
(sym.is(OuterAccessor) || sym.isOuterParamAccessor)
36+
&& !sym.owner.isExtensibleClass
37+
38+
/** The number of times an outer accessor that might be dropped is accessed */
39+
private val accessCount = new mutable.HashMap[Symbol, Int]:
40+
override def default(key: Symbol): Int = 0
41+
42+
private def markAccessed(tree: RefTree)(implicit ctx: Context): Tree =
43+
val sym = tree.symbol
44+
if mightBeDropped(sym) then accessCount(sym) += 1
45+
tree
46+
47+
override def transformIdent(tree: Ident)(using Context): Tree =
48+
markAccessed(tree)
49+
50+
override def transformSelect(tree: Select)(using Context): Tree =
51+
markAccessed(tree)
52+
53+
override def transformTemplate(impl: Template)(using ctx: Context): Tree =
54+
55+
def dropOuterAccessor(stat: Tree): Boolean = stat match
56+
case stat: DefDef
57+
if stat.symbol.is(OuterAccessor)
58+
&& mightBeDropped(stat.symbol)
59+
&& accessCount(stat.symbol) == 0 =>
60+
assert(stat.rhs.isInstanceOf[RefTree], stat)
61+
assert(accessCount(stat.rhs.symbol) > 0)
62+
accessCount(stat.rhs.symbol) -= 1
63+
stat.symbol.dropAfter(thisPhase)
64+
true
65+
case _ =>
66+
false
67+
68+
val droppedParamAccessors = mutable.Set[Symbol]()
69+
70+
def dropOuterParamAccessor(stat: Tree): Boolean = stat match
71+
case stat: ValDef
72+
if stat.symbol.isOuterParamAccessor
73+
&& mightBeDropped(stat.symbol)
74+
&& accessCount(stat.symbol) == 1 =>
75+
droppedParamAccessors += stat.symbol
76+
stat.symbol.dropAfter(thisPhase)
77+
true
78+
case _ =>
79+
false
80+
81+
def dropOuterInit(stat: Tree): Boolean = stat match
82+
case Assign(lhs, rhs) => droppedParamAccessors.remove(lhs.symbol)
83+
case _ => false
84+
85+
val body1 = impl.body
86+
.filterNot(dropOuterAccessor)
87+
.filterNot(dropOuterParamAccessor)
88+
val constr1 =
89+
if droppedParamAccessors.isEmpty then impl.constr
90+
else cpy.DefDef(impl.constr)(
91+
rhs = impl.constr.rhs match {
92+
case rhs @ Block(inits, expr) =>
93+
cpy.Block(rhs)(inits.filterNot(dropOuterInit), expr)
94+
})
95+
assert(droppedParamAccessors.isEmpty,
96+
i"""Failed to eliminate: $droppedParamAccessors
97+
when dropping outer accessors for ${ctx.owner} with
98+
$impl""")
99+
cpy.Template(impl)(constr = constr1, body = body1)
100+
end transformTemplate

compiler/src/dotty/tools/dotc/transform/LambdaLift.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ object LambdaLift {
2323
import ast.tpd._
2424
private class NoPath extends Exception
2525

26+
val name: String = "lambdaLift"
27+
2628
/** The core lambda lift functionality. */
2729
class Lifter(thisPhase: MiniPhase with DenotTransformer)(implicit ctx: Context) {
2830

@@ -500,7 +502,7 @@ class LambdaLift extends MiniPhase with IdentityDenotTransformer { thisPhase =>
500502
import ast.tpd._
501503

502504
/** the following two members override abstract members in Transform */
503-
val phaseName: String = "lambdaLift"
505+
val phaseName: String = LambdaLift.name
504506

505507
override def relaxedTypingInGroup: Boolean = true
506508
// Because it adds free vars as additional proxy parameters

tests/run/i8033.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
trait Okay extends Serializable {
2+
def okay: Okay
3+
}
4+
5+
class Foo {
6+
def okay1: Okay = new Okay() {
7+
val okay: Okay = this
8+
override def toString = "okay1"
9+
}
10+
def okay2: Okay = new Okay {
11+
val okay: Okay = okay1
12+
override def toString = "okay2"
13+
}
14+
}
15+
16+
object Test {
17+
def main(args: Array[String]): Unit = {
18+
val foo = new Foo
19+
assert(roundTrip(foo.okay1).toString == "okay1")
20+
assert(roundTrip(foo.okay2).toString == "okay2")
21+
}
22+
23+
def roundTrip[A](a: A): A = {
24+
import java.io._
25+
26+
val aos = new ByteArrayOutputStream()
27+
val oos = new ObjectOutputStream(aos)
28+
oos.writeObject(a)
29+
oos.close()
30+
val ais = new ByteArrayInputStream(aos.toByteArray())
31+
val ois: ObjectInputStream = new ObjectInputStream(ais)
32+
val newA = ois.readObject()
33+
ois.close()
34+
newA.asInstanceOf[A]
35+
}
36+
}

tests/run/outer-accessors.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
class A:
2+
val a = 2
3+
4+
class B:
5+
val b = 3
6+
7+
trait T:
8+
def t = a + b
9+
10+
val bb = B()
11+
12+
class C extends bb.T:
13+
def result = a + t
14+
15+
@main def Test =
16+
val a = A()
17+
val c = a.C()
18+
assert(c.result == 7)

0 commit comments

Comments
 (0)