From 1572968bfe9963255c8240bec95aa0a726e65778 Mon Sep 17 00:00:00 2001 From: Ondrej Lhotak Date: Fri, 14 Jul 2023 16:51:02 -0400 Subject: [PATCH 1/2] fix #11967: flow typing nullability in pattern matches --- .../dotty/tools/dotc/typer/Nullables.scala | 10 ++++++ .../src/dotty/tools/dotc/typer/Typer.scala | 14 ++++++-- tests/explicit-nulls/pos/flow-match.scala | 32 +++++++++++++++++++ 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/typer/Nullables.scala b/compiler/src/dotty/tools/dotc/typer/Nullables.scala index 9104418d406f..5be8e0aa3060 100644 --- a/compiler/src/dotty/tools/dotc/typer/Nullables.scala +++ b/compiler/src/dotty/tools/dotc/typer/Nullables.scala @@ -190,6 +190,16 @@ object Nullables: // TODO: Add constant pattern if the constant type is not nullable case _ => false + def matchesNull(cdef: CaseDef)(using Context): Boolean = + cdef.guard.isEmpty && patMatchesNull(cdef.pat) + + private def patMatchesNull(pat: Tree)(using Context): Boolean = pat match + case Literal(Constant(null)) => true + case Bind(_, pat) => patMatchesNull(pat) + case Alternative(trees) => trees.exists(patMatchesNull) + case _ if isVarPattern(pat) => true + case _ => false + extension (infos: List[NotNullInfo]) /** Do the current not-null infos imply that `ref` is not null? diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index cb5051ea34ad..0cb1361618ef 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1843,12 +1843,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer /** Special typing of Match tree when the expected type is a MatchType, * and the patterns of the Match tree and the MatchType correspond. */ - def typedDependentMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: MatchType)(using Context): Tree = { + def typedDependentMatchFinish(tree: untpd.Match, sel: Tree, wideSelType0: Type, cases: List[untpd.CaseDef], pt: MatchType)(using Context): Tree = { var caseCtx = ctx + var wideSelType = wideSelType0 + var alreadyStripped = false val cases1 = tree.cases.zip(pt.cases) .map { case (cas, tpe) => val case1 = typedCase(cas, sel, wideSelType, tpe)(using caseCtx) caseCtx = Nullables.afterPatternContext(sel, case1.pat) + if !alreadyStripped && Nullables.matchesNull(case1) then + wideSelType = wideSelType.stripNull + alreadyStripped = true case1 } .asInstanceOf[List[CaseDef]] @@ -1862,10 +1867,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer assignType(cpy.Match(tree)(sel, cases1), sel, cases1) } - def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType: Type, pt: Type)(using Context): List[CaseDef] = + def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType0: Type, pt: Type)(using Context): List[CaseDef] = var caseCtx = ctx + var wideSelType = wideSelType0 + var alreadyStripped = false cases.mapconserve { cas => val case1 = typedCase(cas, sel, wideSelType, pt)(using caseCtx) + if !alreadyStripped && Nullables.matchesNull(case1) then + wideSelType = wideSelType.stripNull + alreadyStripped = true caseCtx = Nullables.afterPatternContext(sel, case1.pat) case1 } diff --git a/tests/explicit-nulls/pos/flow-match.scala b/tests/explicit-nulls/pos/flow-match.scala index 260068b3ac3f..2ed746be81b5 100644 --- a/tests/explicit-nulls/pos/flow-match.scala +++ b/tests/explicit-nulls/pos/flow-match.scala @@ -12,4 +12,36 @@ object MatchTest { // after the null case, s becomes non-nullable case _ => s } + + def f(s: String | Null): String = s match { + case null => "other" + case s2 => s2 + case s3 => s3 + } + + class Foo + + def f2(s: String | Null): String = s match { + case n @ null => "other" + case s2 => s2 + case s3 => s3 + } + + def f3(s: String | Null): String = s match { + case null | "foo" => "other" + case s2 => s2 + case s3 => s3 + } + + def f4(s: String | Null): String = s match { + case _ => "other" + case s2 => s2 + case s3 => s3 + } + + def f5(s: String | Null): String = s match { + case x => "other" + case s2 => s2 + case s3 => s3 + } } From 91438892b2551234c55e4c02bfc542dc56ed54de Mon Sep 17 00:00:00 2001 From: Ondrej Lhotak Date: Sat, 15 Jul 2023 17:03:01 -0400 Subject: [PATCH 2/2] address review comments --- compiler/src/dotty/tools/dotc/typer/Typer.scala | 2 +- tests/explicit-nulls/neg/flow-match.scala | 15 +++++++++++++++ tests/explicit-nulls/pos/flow-match.scala | 11 +++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 tests/explicit-nulls/neg/flow-match.scala diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 0cb1361618ef..44b9844916bd 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1873,10 +1873,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer var alreadyStripped = false cases.mapconserve { cas => val case1 = typedCase(cas, sel, wideSelType, pt)(using caseCtx) + caseCtx = Nullables.afterPatternContext(sel, case1.pat) if !alreadyStripped && Nullables.matchesNull(case1) then wideSelType = wideSelType.stripNull alreadyStripped = true - caseCtx = Nullables.afterPatternContext(sel, case1.pat) case1 } diff --git a/tests/explicit-nulls/neg/flow-match.scala b/tests/explicit-nulls/neg/flow-match.scala new file mode 100644 index 000000000000..e385758261cd --- /dev/null +++ b/tests/explicit-nulls/neg/flow-match.scala @@ -0,0 +1,15 @@ +// Test flow-typing when NotNullInfos are from cases + +object MatchTest { + def f6(s: String | Null): String = s match { + case s2 => s2 // error + case null => "other" // error + case s3 => s3 + } + + def f7(s: String | Null): String = s match { + case null => "other" + case null => "other" // error + case s3 => s3 + } +} diff --git a/tests/explicit-nulls/pos/flow-match.scala b/tests/explicit-nulls/pos/flow-match.scala index 2ed746be81b5..57e2c12b3c68 100644 --- a/tests/explicit-nulls/pos/flow-match.scala +++ b/tests/explicit-nulls/pos/flow-match.scala @@ -44,4 +44,15 @@ object MatchTest { case s2 => s2 case s3 => s3 } + + def f6(s: String | Null): String = s match { + case s3: String => s3 + case null => "other" + case s4 => s4 + } + + def f7(s: String | Null): String = s match { + case s2 => s2.nn + case s3 => s3 + } }