Skip to content

Commit 7b5e964

Browse files
committed
Track nullability in pattern matches
1 parent 3ac0fb1 commit 7b5e964

File tree

4 files changed

+43
-10
lines changed

4 files changed

+43
-10
lines changed

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,9 @@ class TreeChecker extends Phase with SymTransformer {
434434
}
435435
}
436436

437-
override def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef =
437+
override def typedCase(tree: untpd.CaseDef, sel: Tree, selType: Type, pt: Type)(implicit ctx: Context): CaseDef =
438438
withPatSyms(tpd.patVars(tree.pat.asInstanceOf[tpd.Tree])) {
439-
super.typedCase(tree, selType, pt)
439+
super.typedCase(tree, sel, selType, pt)
440440
}
441441

442442
override def typedClosure(tree: untpd.Closure, pt: Type)(implicit ctx: Context): Tree = {

compiler/src/dotty/tools/dotc/typer/Nullables.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,20 @@ object Nullables with
6565
/** Is given reference tracked for nullability? */
6666
def isTracked(ref: TermRef)(given Context) = ref.isStable
6767

68+
def afterPatternContext(sel: Tree, pat: Tree)(given ctx: Context) = (sel, pat) match
69+
case (TrackedRef(ref), Literal(Constant(null))) => ctx.addExcluded(Set(ref))
70+
case _ => ctx
71+
72+
def caseContext(sel: Tree, pat: Tree)(given ctx: Context): Context = sel match
73+
case TrackedRef(ref) if matchesNotNull(pat) => ctx.addExcluded(Set(ref))
74+
case _ => ctx
75+
76+
private def matchesNotNull(pat: Tree)(given Context): Boolean = pat match
77+
case _: Typed | _: UnApply => true
78+
case Alternative(pats) => pats.forall(matchesNotNull)
79+
// TODO: Add constant pattern if the constant type is not nullable
80+
case _ => false
81+
6882
given (excluded: List[Excluded])
6983
def containsRef(ref: TermRef): Boolean =
7084
excluded.exists(_.contains(ref))

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,13 +1138,18 @@ class Typer extends Namer
11381138

11391139
// Overridden in InlineTyper for inline matches
11401140
def typedMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: Type)(implicit ctx: Context): Tree = {
1141-
val cases1 = harmonic(harmonize, pt)(typedCases(cases, wideSelType, pt.dropIfProto))
1141+
val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto))
11421142
.asInstanceOf[List[CaseDef]]
11431143
assignType(cpy.Match(tree)(sel, cases1), sel, cases1)
11441144
}
11451145

1146-
def typedCases(cases: List[untpd.CaseDef], selType: Type, pt: Type)(implicit ctx: Context): List[CaseDef] =
1147-
cases.mapconserve(typedCase(_, selType, pt))
1146+
def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType: Type, pt: Type)(implicit ctx: Context): List[CaseDef] =
1147+
var caseCtx = ctx
1148+
cases.mapconserve { cas =>
1149+
val case1 = typedCase(cas, sel, wideSelType, pt)(given caseCtx)
1150+
caseCtx = Nullables.afterPatternContext(sel, case1.pat)
1151+
case1
1152+
}
11481153

11491154
/** - strip all instantiated TypeVars from pattern types.
11501155
* run/reducable.scala is a test case that shows stripping typevars is necessary.
@@ -1171,7 +1176,7 @@ class Typer extends Namer
11711176
}
11721177

11731178
/** Type a case. */
1174-
def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = {
1179+
def typedCase(tree: untpd.CaseDef, sel: Tree, wideSelType: Type, pt: Type)(implicit ctx: Context): CaseDef = {
11751180
val originalCtx = ctx
11761181
val gadtCtx: Context = ctx.fresh.setFreshGADTBounds
11771182

@@ -1184,8 +1189,10 @@ class Typer extends Namer
11841189
assignType(cpy.CaseDef(tree)(pat1, guard1, body1), pat1, body1)
11851190
}
11861191

1187-
val pat1 = typedPattern(tree.pat, selType)(gadtCtx)
1188-
caseRest(pat1)(gadtCtx.fresh.setNewScope)
1192+
val pat1 = typedPattern(tree.pat, wideSelType)(gadtCtx)
1193+
caseRest(pat1)(
1194+
given Nullables.caseContext(sel, pat1)(
1195+
given gadtCtx.fresh.setNewScope))
11891196
}
11901197

11911198
def typedLabeled(tree: untpd.Labeled)(implicit ctx: Context): Labeled = {
@@ -1205,7 +1212,6 @@ class Typer extends Namer
12051212
caseRest(ctx.fresh.setFreshGADTBounds.setNewScope)
12061213
}
12071214

1208-
12091215
def typedReturn(tree: untpd.Return)(implicit ctx: Context): Return = {
12101216
def returnProto(owner: Symbol, locals: Scope): Type =
12111217
if (owner.isConstructor) defn.UnitType
@@ -1263,7 +1269,7 @@ class Typer extends Namer
12631269
def typedTry(tree: untpd.Try, pt: Type)(implicit ctx: Context): Try = {
12641270
val expr2 :: cases2x = harmonic(harmonize, pt) {
12651271
val expr1 = typed(tree.expr, pt.dropIfProto)
1266-
val cases1 = typedCases(tree.cases, defn.ThrowableType, pt.dropIfProto)
1272+
val cases1 = typedCases(tree.cases, EmptyTree, defn.ThrowableType, pt.dropIfProto)
12671273
expr1 :: cases1
12681274
}
12691275
val finalizer1 = typed(tree.finalizer, defn.UnitType)

tests/pos/nullable.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,19 @@ def test: Unit =
3030
}
3131
then ()
3232

33+
x match
34+
case _: String =>
35+
if x == null then impossible(new T{})
36+
37+
val y: Any = List(x)
38+
y match
39+
case y1 :: ys => if y == null then impossible(new T{})
40+
case Some(_) | Seq(_: _*) => if y == null then impossible(new T{})
41+
42+
x match
43+
case null =>
44+
case _ => if x == null then impossible(new T{})
45+
3346
if x == null then return
3447
if x == null then impossible(new T{})
3548

0 commit comments

Comments
 (0)