Skip to content

Add enum exhaustivity checking #2197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ object Printers {
val cyclicErrors: Printer = noPrinter
val pickling: Printer = noPrinter
val inlining: Printer = noPrinter
val exhaustivity: Printer = noPrinter
}
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ class IsInstanceOfEvaluator extends MiniPhaseTransform { thisTransformer =>
(scTrait && selTrait)

val inMatch = s.qualifier.symbol is Case
// FIXME: This will misclassify case objects! We need to find another way to characterize
// isInstanceOfs generated by matches.
// Probably the most robust way is to use another symbol for the isInstanceOf method.

if (valueClassesOrAny) tree
else if (knownStatically)
Expand Down
12 changes: 9 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,14 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisTran
private def transformAnnot(annot: Annotation)(implicit ctx: Context): Annotation =
annot.derivedAnnotation(transformAnnot(annot.tree))

private def registerChild(sym: Symbol, tp: Type)(implicit ctx: Context) = {
val cls = tp.classSymbol
if (cls.is(Sealed)) cls.addAnnotation(Annotation.makeChild(sym))
}

private def transformMemberDef(tree: MemberDef)(implicit ctx: Context): Unit = {
val sym = tree.symbol
if (sym.is(CaseVal, butNot = Method | Module)) registerChild(sym, sym.info)
sym.transformAnnotations(transformAnnot)
}

Expand Down Expand Up @@ -227,9 +233,9 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisTran

// Add Child annotation to sealed parents unless current class is anonymous
if (!sym.isAnonymousClass) // ignore anonymous class
for (parent <- sym.asClass.classInfo.classParents) {
val pclazz = parent.classSymbol
if (pclazz.is(Sealed)) pclazz.addAnnotation(Annotation.makeChild(sym))
sym.asClass.classInfo.classParents.foreach { parent =>
val sym2 = if (sym.is(Module)) sym.sourceModule else sym
registerChild(sym2, parent)
}

tree
Expand Down
161 changes: 64 additions & 97 deletions compiler/src/dotty/tools/dotc/transform/patmat/Space.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import core.StdNames._
import core.NameOps._
import core.Constants._
import reporting.diagnostic.messages._
import config.Printers.{ exhaustivity => debug }

/** Space logic for checking exhaustivity and unreachability of pattern matching
*
Expand All @@ -28,8 +29,6 @@ import reporting.diagnostic.messages._
* 3. A union of spaces `S1 | S2 | ...` is a space
* 4. For a case class Kon(x1: T1, x2: T2, .., xn: Tn), if S1, S2, ..., Sn
* are spaces, then `Kon(S1, S2, ..., Sn)` is a space.
* 5. A constant `Const(value, T)` is a point in space
* 6. A stable identifier `Var(sym, T)` is a space
*
* For the problem of exhaustivity check, its formulation in terms of space is as follows:
*
Expand Down Expand Up @@ -67,15 +66,6 @@ case class Kon(tp: Type, params: List[Space]) extends Space
/** Union of spaces */
case class Or(spaces: List[Space]) extends Space

/** Point in space */
sealed trait Point extends Space

/** Point representing variables(stable identifier) in patterns */
case class Var(sym: Symbol, tp: Type) extends Point

/** Point representing literal constants in patterns */
case class Const(value: Constant, tp: Type) extends Point

/** abstract space logic */
trait SpaceLogic {
/** Is `tp1` a subtype of `tp2`? */
Expand All @@ -97,6 +87,9 @@ trait SpaceLogic {
/** Get components of decomposable types */
def decompose(tp: Type): List[Space]

/** Display space in string format */
def show(sp: Space): String

/** Simplify space using the laws, there's no nested union after simplify */
def simplify(space: Space): Space = space match {
case Kon(tp, spaces) =>
Expand Down Expand Up @@ -137,7 +130,7 @@ trait SpaceLogic {
def tryDecompose1(tp: Type) = canDecompose(tp) && isSubspace(Or(decompose(tp)), b)
def tryDecompose2(tp: Type) = canDecompose(tp) && isSubspace(a, Or(decompose(tp)))

(a, b) match {
val res = (a, b) match {
case (Empty, _) => true
case (_, Empty) => false
case (Or(ss), _) => ss.forall(isSubspace(_, b))
Expand All @@ -157,25 +150,19 @@ trait SpaceLogic {
simplify(minus(a, b)) == Empty
case (Kon(tp1, ss1), Kon(tp2, ss2)) =>
isEqualType(tp1, tp2) && ss1.zip(ss2).forall((isSubspace _).tupled)
case (Const(v1, _), Const(v2, _)) => v1 == v2
case (Const(_, tp1), Typ(tp2, _)) => isSubType(tp1, tp2) || tryDecompose2(tp2)
case (Const(_, _), Or(ss)) => ss.exists(isSubspace(a, _))
case (Const(_, _), _) => false
case (_, Const(_, _)) => false
case (Var(x, _), Var(y, _)) => x == y
case (Var(_, tp1), Typ(tp2, _)) => isSubType(tp1, tp2) || tryDecompose2(tp2)
case (Var(_, _), Or(ss)) => ss.exists(isSubspace(a, _))
case (Var(_, _), _) => false
case (_, Var(_, _)) => false
}

debug.println(s"${show(a)} < ${show(b)} = $res")

res
}

/** Intersection of two spaces */
def intersect(a: Space, b: Space): Space = {
def tryDecompose1(tp: Type) = intersect(Or(decompose(tp)), b)
def tryDecompose2(tp: Type) = intersect(a, Or(decompose(tp)))

(a, b) match {
val res = (a, b) match {
case (Empty, _) | (_, Empty) => Empty
case (_, Or(ss)) => Or(ss.map(intersect(a, _)).filterConserve(_ ne Empty))
case (Or(ss), _) => Or(ss.map(intersect(_, b)).filterConserve(_ ne Empty))
Expand All @@ -199,39 +186,19 @@ trait SpaceLogic {
if (!isEqualType(tp1, tp2)) Empty
else if (ss1.zip(ss2).exists(p => simplify(intersect(p._1, p._2)) == Empty)) Empty
else Kon(tp1, ss1.zip(ss2).map((intersect _).tupled))
case (Const(v1, _), Const(v2, _)) =>
if (v1 == v2) a else Empty
case (Const(_, tp1), Typ(tp2, _)) =>
if (isSubType(tp1, tp2)) a
else if (canDecompose(tp2)) tryDecompose2(tp2)
else Empty
case (Const(_, _), _) => Empty
case (Typ(tp1, _), Const(_, tp2)) =>
if (isSubType(tp2, tp1)) b
else if (canDecompose(tp1)) tryDecompose1(tp1)
else Empty
case (_, Const(_, _)) => Empty
case (Var(x, _), Var(y, _)) =>
if (x == y) a else Empty
case (Var(_, tp1), Typ(tp2, _)) =>
if (isSubType(tp1, tp2)) a
else if (canDecompose(tp2)) tryDecompose2(tp2)
else Empty
case (Var(_, _), _) => Empty
case (Typ(tp1, _), Var(_, tp2)) =>
if (isSubType(tp2, tp1)) b
else if (canDecompose(tp1)) tryDecompose1(tp1)
else Empty
case (_, Var(_, _)) => Empty
}

debug.println(s"${show(a)} & ${show(b)} = ${show(res)}")

res
}

/** The space of a not covered by b */
def minus(a: Space, b: Space): Space = {
def tryDecompose1(tp: Type) = minus(Or(decompose(tp)), b)
def tryDecompose2(tp: Type) = minus(a, Or(decompose(tp)))

(a, b) match {
val res = (a, b) match {
case (Empty, _) => Empty
case (_, Empty) => a
case (Typ(tp1, _), Typ(tp2, _)) =>
Expand Down Expand Up @@ -264,26 +231,11 @@ trait SpaceLogic {
Or(ss1.zip(ss2).map((minus _).tupled).zip(0 to ss2.length - 1).map {
case (ri, i) => Kon(tp1, ss1.updated(i, ri))
})
case (Const(v1, _), Const(v2, _)) =>
if (v1 == v2) Empty else a
case (Const(_, tp1), Typ(tp2, _)) =>
if (isSubType(tp1, tp2)) Empty
else if (canDecompose(tp2)) tryDecompose2(tp2)
else a
case (Const(_, _), _) => a
case (Typ(tp1, _), Const(_, tp2)) => // Boolean & Java enum
if (canDecompose(tp1)) tryDecompose1(tp1)
else a
case (_, Const(_, _)) => a
case (Var(x, _), Var(y, _)) =>
if (x == y) Empty else a
case (Var(_, tp1), Typ(tp2, _)) =>
if (isSubType(tp1, tp2)) Empty
else if (canDecompose(tp2)) tryDecompose2(tp2)
else a
case (Var(_, _), _) => a
case (_, Var(_, _)) => a
}

debug.println(s"${show(a)} - ${show(b)} = ${show(res)}")

res
}
}

Expand All @@ -297,19 +249,14 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
* otherwise approximate extractors to Empty
*/
def project(pat: Tree, roundUp: Boolean = true)(implicit ctx: Context): Space = pat match {
case Literal(c) => Const(c, c.tpe)
case _: BackquotedIdent => Var(pat.symbol, pat.tpe)
case Literal(c) =>
if (c.value.isInstanceOf[Symbol])
Typ(c.value.asInstanceOf[Symbol].termRef, false)
else
Typ(ConstantType(c), false)
case _: BackquotedIdent => Typ(pat.tpe, false)
case Ident(_) | Select(_, _) =>
pat.tpe.stripAnnots match {
case tp: TermRef =>
if (pat.symbol.is(Enum))
Const(Constant(pat.symbol), tp)
else if (tp.underlyingIterator.exists(_.classSymbol.is(Module)))
Typ(tp.widenTermRefExpr.stripAnnots, false)
else
Var(pat.symbol, tp)
case tp => Typ(tp, false)
}
Typ(pat.tpe.stripAnnots, false)
case Alternative(trees) => Or(trees.map(project(_, roundUp)))
case Bind(_, pat) => project(pat)
case UnApply(_, _, pats) =>
Expand Down Expand Up @@ -345,7 +292,9 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
/** Is `tp1` a subtype of `tp2`? */
def isSubType(tp1: Type, tp2: Type): Boolean = {
// check SI-9657 and tests/patmat/gadt.scala
erase(tp1) <:< erase(tp2)
val res = erase(tp1) <:< erase(tp2)
debug.println(s"${tp1.show} <:< ${tp2.show} = $res")
res
}

def isEqualType(tp1: Type, tp2: Type): Boolean = tp1 =:= tp2
Expand Down Expand Up @@ -373,29 +322,37 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
}
}

debug.println(s"candidates for ${tp.show} : [${children.map(_.show).mkString(", ")}]")

tp match {
case OrType(tp1, tp2) => List(Typ(tp1, true), Typ(tp2, true))
case _ if tp =:= ctx.definitions.BooleanType =>
List(
Const(Constant(true), ctx.definitions.BooleanType),
Const(Constant(false), ctx.definitions.BooleanType)
Typ(ConstantType(Constant(true)), true),
Typ(ConstantType(Constant(false)), true)
)
case _ if tp.classSymbol.is(Enum) =>
children.map(sym => Const(Constant(sym), tp))
children.map(sym => Typ(sym.termRef, true))
case _ =>
val parts = children.map { sym =>
if (sym.is(ModuleClass))
sym.asClass.classInfo.selfType
refine(tp, sym.sourceModule.termRef)
else if (sym.isTerm)
refine(tp, sym.termRef)
else if (sym.info.typeParams.length > 0 || tp.isInstanceOf[TypeRef])
refine(tp, sym.typeRef)
else
sym.typeRef
} filter { tpe =>
// Child class may not always be subtype of parent:
// GADT & path-dependent types
tpe <:< expose(tp)
val res = tpe <:< expose(tp)
if (!res) debug.println(s"unqualified child ousted: ${tpe.show} !< ${tp.show}")
res
}

debug.println(s"${tp.show} decomposes to [${parts.map(_.show).mkString(", ")}]")

parts.map(Typ(_, true))
}
}
Expand All @@ -409,20 +366,26 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
* `path2`, then return `path1.B`.
*/
def refine(tp1: Type, tp2: Type): Type = (tp1, tp2) match {
case (tp1: RefinedType, _) => tp1.wrapIfMember(refine(tp1.parent, tp2))
case (tp1: RefinedType, _: TypeRef) => tp1.wrapIfMember(refine(tp1.parent, tp2))
case (tp1: HKApply, _) => refine(tp1.superType, tp2)
case (TypeRef(ref1: TypeProxy, _), tp2 @ TypeRef(ref2: TypeProxy, name)) =>
if (ref1.underlying <:< ref2.underlying) TypeRef(ref1, name) else tp2
case (TypeRef(ref1: TypeProxy, _), tp2 @ TypeRef(ref2: TypeProxy, _)) =>
if (ref1.underlying <:< ref2.underlying) tp2.derivedSelect(ref1) else tp2
case (TypeRef(ref1: TypeProxy, _), tp2 @ TermRef(ref2: TypeProxy, _)) =>
if (ref1.underlying <:< ref2.underlying) tp2.derivedSelect(ref1) else tp2
case _ => tp2
}

/** Abstract sealed types, or-types, Boolean and Java enums can be decomposed */
def canDecompose(tp: Type): Boolean = {
tp.classSymbol.is(allOf(Abstract, Sealed)) ||
val res = tp.classSymbol.is(allOf(Abstract, Sealed)) ||
tp.classSymbol.is(allOf(Trait, Sealed)) ||
tp.isInstanceOf[OrType] ||
tp =:= ctx.definitions.BooleanType ||
tp.classSymbol.is(Enum)
tp.classSymbol.is(allOf(Enum, Sealed)) // Enum value doesn't have Sealed flag

debug.println(s"decomposable: ${tp.show} = $res")

res
}

/** Show friendly type name with current scope in mind
Expand Down Expand Up @@ -474,14 +437,12 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
def show(s: Space): String = {
def doShow(s: Space, mergeList: Boolean = false): String = s match {
case Empty => ""
case Const(v, _) => v.show
case Var(x, _) => x.show
case Typ(c: ConstantType, _) => c.value.show
case Typ(tp: TermRef, _) => tp.symbol.showName
case Typ(tp, decomposed) =>
val sym = tp.widen.classSymbol

if (sym.is(ModuleClass))
showType(tp)
else if (ctx.definitions.isTupleType(tp))
if (ctx.definitions.isTupleType(tp))
signature(tp).map(_ => "_").mkString("(", ", ", ")")
else if (sym.showFullName == "scala.collection.immutable.::")
if (mergeList) "_" else "List(_)"
Expand Down Expand Up @@ -523,7 +484,9 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
}

val Match(sel, cases) = tree
isCheckable(sel.tpe.widen.deAnonymize.dealiasKeepAnnots)
val res = isCheckable(sel.tpe.widen.deAnonymize.dealiasKeepAnnots)
debug.println(s"checkable: ${sel.show} = $res")
res
}


Expand Down Expand Up @@ -584,7 +547,11 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
val selTyp = sel.tpe.widen.deAnonymize.dealias


val patternSpace = cases.map(x => project(x.pat)).reduce((a, b) => Or(List(a, b)))
val patternSpace = cases.map({ x =>
val space = project(x.pat)
debug.println(s"${x.pat.show} projects to ${show(space)}")
space
}).reduce((a, b) => Or(List(a, b)))
val uncovered = simplify(minus(Typ(selTyp, true), patternSpace))

if (uncovered != Empty)
Expand Down
22 changes: 22 additions & 0 deletions tests/patmat/enum-HList.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
enum HLst {
case HCons[+Hd, +Tl <: HLst](hd: Hd, tl: Tl)
case HNil
}

object Test {
import HLst._
def length(hl: HLst): Int = hl match {
case HCons(_, tl) => 1 + length(tl)
case HNil => 0
}
def sumInts(hl: HLst): Int = hl match {
case HCons(x: Int, tl) => x + sumInts(tl)
case HCons(_, tl) => sumInts(tl)
case HNil => 0
}
def main(args: Array[String]) = {
val hl = HCons(1, HCons("A", HNil))
assert(length(hl) == 2, length(hl))
assert(sumInts(hl) == 1)
}
}
Loading