Skip to content

Commit 0ee3281

Browse files
committed
Implement access checking for enum cases
Previously, enum cases moved from the enum to its companion object could "accidentally" refer to definitions defined in the object, but inaccessible from the enum. We now check that no such accesses occur.
1 parent fa0a25f commit 0ee3281

File tree

9 files changed

+234
-8
lines changed

9 files changed

+234
-8
lines changed

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ object DesugarEnums {
6868

6969
/** Add implied flags to an enum class or an enum case */
7070
def addEnumFlags(cdef: TypeDef)(implicit ctx: Context) =
71-
if (cdef.mods.hasMod[Mod.Enum]) cdef.withFlags(cdef.mods.flags | Abstract | Sealed)
72-
else if (isEnumCase(cdef)) cdef.withFlags(cdef.mods.flags | Final)
71+
if (cdef.mods.hasMod[Mod.Enum]) cdef.withMods(cdef.mods.withFlags(cdef.mods.flags | Abstract | Sealed))
72+
else if (isEnumCase(cdef)) cdef.withMods(cdef.mods.withFlags(cdef.mods.flags | Final))
7373
else cdef
7474

7575
private def valuesDot(name: String) = Select(Ident(nme.DOLLAR_VALUES), name.toTermName)

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import Decorators._
2828
import Uniques._
2929
import ErrorReporting.{err, errorType}
3030
import config.Printers.typr
31+
import NameKinds.DefaultGetterName
3132

3233
import collection.mutable
3334
import SymDenotations.NoCompleter
@@ -741,9 +742,100 @@ trait Checking {
741742
tp.foreachPart(check, stopAtStatic = true)
742743
tp
743744
}
745+
746+
/** Check that all non-synthetic references of the form `<ident>` or
747+
* `this.<ident>` in `tree` that refer to a member of `badOwner` are
748+
* `allowed`.
749+
*/
750+
def checkRefsLegal(tree: tpd.Tree, badOwner: Symbol, allowed: (Name, Symbol) => Boolean, where: String)(implicit ctx: Context): Unit = {
751+
tree.foreachSubTree { tree =>
752+
tree match {
753+
case Ident(_) | Select(This(_), _) if tree.pos.isSourceDerived =>
754+
val sym = tree.symbol
755+
if (sym.maybeOwner == badOwner && !allowed(tree.asInstanceOf[RefTree].name, sym))
756+
ctx.error(i"illegal reference to $sym from $where: $tree // ${tree.toString}", tree.pos)
757+
case _ =>
758+
}
759+
}
760+
}
761+
762+
/** Check that all case classes that extend `scala.Enum` are `enum` cases */
763+
def checkEnum(cdef: untpd.TypeDef, cls: Symbol)(implicit ctx: Context): Unit = {
764+
import untpd.modsDeco
765+
def isEnumAnonCls =
766+
cls.isAnonymousClass &&
767+
cls.owner.isTerm &&
768+
(cls.owner.flagsUNSAFE.is(Case) || cls.owner.name == nme.DOLLAR_NEW)
769+
if (!cdef.mods.hasMod[untpd.Mod.EnumCase] && !isEnumAnonCls)
770+
ctx.error(em"normal case $cls in ${cls.owner} cannot extend an enum", cdef.pos)
771+
}
772+
773+
/** Check that all references coming from enum cases in an enum companion object
774+
* are legal.
775+
* @param cdef the enum companion object class
776+
* @param enumCtx the context immediately enclosing the corresponding enum
777+
*/
778+
private def checkEnumCaseRefsLegal(cdef: TypeDef, enumCtx: Context)(implicit ctx: Context): Unit = {
779+
def check(tree: Tree) = {
780+
// allow access to `sym` if a typedIdent just outside the enclosing enum
781+
// would have produced the same symbol without errors
782+
def allowAccess(name: Name, sym: Symbol): Boolean = {
783+
val testCtx = enumCtx.fresh.setNewTyperState()
784+
val ref = ctx.typer.typedIdent(untpd.Ident(name), WildcardType)(testCtx)
785+
ref.symbol == sym && !testCtx.reporter.hasErrors
786+
}
787+
checkRefsLegal(tree, cdef.symbol, allowAccess, "enum case")
788+
}
789+
cdef.rhs match {
790+
case impl: Template =>
791+
for (stat <- impl.body)
792+
if (stat.symbol.is(Case))
793+
stat match {
794+
case TypeDef(_, Template(DefDef(_, tparams, vparamss, _, _), parents, _, _)) =>
795+
tparams.foreach(check)
796+
vparamss.foreach(_.foreach(check))
797+
parents.foreach(check)
798+
case vdef: ValDef =>
799+
vdef.rhs match {
800+
case Block((clsDef @ TypeDef(_, impl: Template)) :: Nil, _)
801+
if clsDef.symbol.isAnonymousClass =>
802+
impl.parents.foreach(check)
803+
case _ =>
804+
}
805+
case _ =>
806+
}
807+
else if (stat.symbol.is(Module) && stat.symbol.linkedClass.is(Case))
808+
stat match {
809+
case TypeDef(_, impl: Template) =>
810+
for ((defaultGetter @
811+
DefDef(DefaultGetterName(nme.CONSTRUCTOR, _), _, _, _, _)) <- impl.body)
812+
check(defaultGetter.rhs)
813+
case _ =>
814+
}
815+
case _ =>
816+
}
817+
}
818+
819+
/** Check all enum cases in all enum companions in `stats` for legal accesses.
820+
* @param enumContexts a map from`enum` symbols to the contexts enclosing their definitions
821+
*/
822+
def checkEnumCompanions(stats: List[Tree], enumContexts: collection.Map[Symbol, Context])(implicit ctx: Context): List[Tree] = {
823+
for (stat @ TypeDef(_, _) <- stats)
824+
if (stat.symbol.is(Module))
825+
for (enumContext <- enumContexts.get(stat.symbol.linkedClass))
826+
checkEnumCaseRefsLegal(stat, enumContext)
827+
stats
828+
}
829+
}
830+
831+
trait ReChecking extends Checking {
832+
import tpd._
833+
override def checkEnum(cdef: untpd.TypeDef, cls: Symbol)(implicit ctx: Context): Unit = ()
834+
override def checkRefsLegal(tree: tpd.Tree, badOwner: Symbol, allowed: (Name, Symbol) => Boolean, where: String)(implicit ctx: Context): Unit = ()
835+
override def checkEnumCompanions(stats: List[Tree], enumContexts: collection.Map[Symbol, Context])(implicit ctx: Context): List[Tree] = stats
744836
}
745837

746-
trait NoChecking extends Checking {
838+
trait NoChecking extends ReChecking {
747839
import tpd._
748840
override def checkNonCyclic(sym: Symbol, info: TypeBounds, reportErrors: Boolean)(implicit ctx: Context): Type = info
749841
override def checkValue(tree: Tree, proto: Type)(implicit ctx: Context): tree.type = tree

compiler/src/dotty/tools/dotc/typer/ReTyper.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import config.Printers.typr
2222
*
2323
* Otherwise, everything is as in Typer.
2424
*/
25-
class ReTyper extends Typer {
25+
class ReTyper extends Typer with ReChecking {
2626
import tpd._
2727

2828
private def assertTyped(tree: untpd.Tree)(implicit ctx: Context): Unit =

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,7 +1406,7 @@ class Typer extends Namer
14061406
if (sym.isInlineMethod) Inliner.registerInlineInfo(sym, _ => rhs1)
14071407

14081408
assignType(cpy.DefDef(ddef)(name, tparams1, vparamss1, tpt1, rhs1), sym)
1409-
//todo: make sure dependent method types do not depend on implicits or by-name params
1409+
//todo: make sure dependent method types do not depend on implicits or by-name params
14101410
}
14111411

14121412
def typedTypeDef(tdef: untpd.TypeDef, sym: Symbol)(implicit ctx: Context): Tree = track("typedTypeDef") {
@@ -1506,6 +1506,7 @@ class Typer extends Namer
15061506
checkVariance(impl1)
15071507
if (!cls.is(AbstractOrTrait) && !ctx.isAfterTyper)
15081508
checkRealizableBounds(cls, cdef.namePos)
1509+
if (cls.is(Case) && cls.derivesFrom(defn.EnumClass)) checkEnum(cdef, cls)
15091510
val cdef1 = assignType(cpy.TypeDef(cdef)(name, impl1), cls)
15101511
if (ctx.phase.isTyper && cdef1.tpe.derivesFrom(defn.DynamicClass) && !ctx.dynamicsEnabled) {
15111512
val isRequired = parents1.exists(_.tpe.isRef(defn.DynamicClass))
@@ -1813,6 +1814,8 @@ class Typer extends Namer
18131814

18141815
def typedStats(stats: List[untpd.Tree], exprOwner: Symbol)(implicit ctx: Context): List[tpd.Tree] = {
18151816
val buf = new mutable.ListBuffer[Tree]
1817+
val enumContexts = new mutable.HashMap[Symbol, Context]
1818+
// A map from `enum` symbols to the contexts enclosing their definitions
18161819
@tailrec def traverse(stats: List[untpd.Tree])(implicit ctx: Context): List[Tree] = stats match {
18171820
case (imp: untpd.Import) :: rest =>
18181821
val imp1 = typed(imp)
@@ -1827,6 +1830,12 @@ class Typer extends Namer
18271830
case mdef1: DefDef if Inliner.hasBodyToInline(mdef1.symbol) =>
18281831
buf ++= inlineExpansion(mdef1)
18291832
case mdef1 =>
1833+
import untpd.modsDeco
1834+
mdef match {
1835+
case mdef: untpd.TypeDef if mdef.mods.hasMod[untpd.Mod.Enum] =>
1836+
enumContexts(mdef1.symbol) = ctx
1837+
case _ =>
1838+
}
18301839
buf += mdef1
18311840
}
18321841
traverse(rest)
@@ -1846,7 +1855,7 @@ class Typer extends Namer
18461855
val exprOwnerOpt = if (exprOwner == ctx.owner) None else Some(exprOwner)
18471856
ctx.withProperty(ExprOwner, exprOwnerOpt)
18481857
}
1849-
traverse(stats)(localCtx)
1858+
checkEnumCompanions(traverse(stats)(localCtx), enumContexts)
18501859
}
18511860

18521861
/** Given an inline method `mdef`, the method rewritten so that its body

tests/neg/enums.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ enum E3[-T <: Ordered[T]] {
1717
case C // error: cannot determine type argument
1818
}
1919

20+
enum E4 {
21+
case C
22+
}
23+
24+
case class C4() extends E4 // error: cannot extend enum
25+
case object O4 extends E4 // error: cannot extend enum
26+
2027
enum Option[+T] {
2128
case Some(x: T)
2229
case None

tests/neg/enumsAccess.scala

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package enums
2+
3+
object test1 {
4+
5+
enum E4 {
6+
case C1(x: INT) // error: illegal reference
7+
case C2(x: Int = defaultX) // error: illegal reference
8+
case C3[T <: INT] // error: illegal reference
9+
}
10+
11+
object E4 {
12+
type INT = Integer
13+
val defaultX = 2
14+
}
15+
}
16+
17+
object test2 {
18+
import E5._
19+
object E5 {
20+
type INT = Integer
21+
val defaultX = 2
22+
}
23+
24+
enum E5 {
25+
case C1(x: INT) // ok
26+
case C2(x: Int = defaultX) // ok
27+
case C3[T <: INT] // ok
28+
}
29+
}
30+
31+
object test3 {
32+
object E5 {
33+
type INT = Integer
34+
val defaultX = 2
35+
}
36+
37+
import E5._
38+
39+
enum E5 {
40+
case C1(x: INT) // ok
41+
case C2(x: Int = defaultX)// ok
42+
case C3[T <: INT] // ok
43+
}
44+
}
45+
46+
object test4 {
47+
48+
enum E5 {
49+
case C1(x: INT) // error: illegal reference
50+
case C2(x: Int = defaultX) // error: illegal reference
51+
case C3[T <: INT] // error: illegal reference
52+
}
53+
54+
import E5._
55+
56+
object E5 {
57+
type INT = Integer
58+
val defaultX = 2
59+
}
60+
}
61+
62+
object test5 {
63+
enum E5[T](x: T) {
64+
case C3() extends E5[INT](defaultX)// error: illegal reference // error: illegal reference
65+
case C4 extends E5[INT](defaultX) // error: illegal reference // error: illegal reference
66+
case C5 extends E5[E5[_]](E5.this) // error: type mismatch
67+
}
68+
69+
object E5 {
70+
type INT = Integer
71+
val defaultX = 2
72+
}
73+
}
74+
75+
object test6 {
76+
import E5._
77+
enum E5[T](x: T) {
78+
case C3() extends E5[INT](defaultX) // ok
79+
case C4 extends E5[INT](defaultX) // ok
80+
}
81+
82+
object E5 {
83+
type INT = Integer
84+
val defaultX = 2
85+
}
86+
}

tests/pos/enum-List-control.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
abstract sealed class List[T] extends Enum
22
object List {
3-
final case class Cons[T](x: T, xs: List[T]) extends List[T] {
3+
final class Cons[T](x: T, xs: List[T]) extends List[T] {
44
def enumTag = 0
55
}
6-
final case class Nil[T]() extends List[T] {
6+
object Cons {
7+
def apply[T](x: T, xs: List[T]): List[T] = new Cons(x, xs)
8+
}
9+
final class Nil[T]() extends List[T] {
710
def enumTag = 1
811
}
12+
object Nil {
13+
def apply[T](): List[T] = new Nil()
14+
}
915
}
1016
object Test {
1117
import List._

tests/run/enum-Color.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
enum Color {
22
case Red, Green, Blue
3+
class Color // Just to throw a spanner in the works
34
}
45

56
object Test {

tests/run/generic/Enum.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package generic
2+
3+
trait Enum {
4+
def enumTag: Int
5+
}
6+
7+
object runtime {
8+
class EnumValues[E <: Enum] {
9+
private[this] var myMap: Map[Int, E] = Map()
10+
private[this] var fromNameCache: Map[String, E] = null
11+
12+
def register(v: E) = {
13+
require(!myMap.contains(v.enumTag))
14+
myMap = myMap.updated(v.enumTag, v)
15+
fromNameCache = null
16+
}
17+
18+
def fromInt: Map[Int, E] = myMap
19+
def fromName: Map[String, E] = {
20+
if (fromNameCache == null) fromNameCache = myMap.values.map(v => v.toString -> v).toMap
21+
fromNameCache
22+
}
23+
def values: Iterable[E] = myMap.values
24+
}
25+
}

0 commit comments

Comments
 (0)