diff --git a/compiler/src/dotty/tools/dotc/typer/Nullables.scala b/compiler/src/dotty/tools/dotc/typer/Nullables.scala index 9104418d406f..5be8e0aa3060 100644 --- a/compiler/src/dotty/tools/dotc/typer/Nullables.scala +++ b/compiler/src/dotty/tools/dotc/typer/Nullables.scala @@ -190,6 +190,16 @@ object Nullables: // TODO: Add constant pattern if the constant type is not nullable case _ => false + def matchesNull(cdef: CaseDef)(using Context): Boolean = + cdef.guard.isEmpty && patMatchesNull(cdef.pat) + + private def patMatchesNull(pat: Tree)(using Context): Boolean = pat match + case Literal(Constant(null)) => true + case Bind(_, pat) => patMatchesNull(pat) + case Alternative(trees) => trees.exists(patMatchesNull) + case _ if isVarPattern(pat) => true + case _ => false + extension (infos: List[NotNullInfo]) /** Do the current not-null infos imply that `ref` is not null? diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index cb5051ea34ad..44b9844916bd 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1843,12 +1843,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer /** Special typing of Match tree when the expected type is a MatchType, * and the patterns of the Match tree and the MatchType correspond. */ - def typedDependentMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: MatchType)(using Context): Tree = { + def typedDependentMatchFinish(tree: untpd.Match, sel: Tree, wideSelType0: Type, cases: List[untpd.CaseDef], pt: MatchType)(using Context): Tree = { var caseCtx = ctx + var wideSelType = wideSelType0 + var alreadyStripped = false val cases1 = tree.cases.zip(pt.cases) .map { case (cas, tpe) => val case1 = typedCase(cas, sel, wideSelType, tpe)(using caseCtx) caseCtx = Nullables.afterPatternContext(sel, case1.pat) + if !alreadyStripped && Nullables.matchesNull(case1) then + wideSelType = wideSelType.stripNull + alreadyStripped = true case1 } .asInstanceOf[List[CaseDef]] @@ -1862,11 +1867,16 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer assignType(cpy.Match(tree)(sel, cases1), sel, cases1) } - def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType: Type, pt: Type)(using Context): List[CaseDef] = + def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType0: Type, pt: Type)(using Context): List[CaseDef] = var caseCtx = ctx + var wideSelType = wideSelType0 + var alreadyStripped = false cases.mapconserve { cas => val case1 = typedCase(cas, sel, wideSelType, pt)(using caseCtx) caseCtx = Nullables.afterPatternContext(sel, case1.pat) + if !alreadyStripped && Nullables.matchesNull(case1) then + wideSelType = wideSelType.stripNull + alreadyStripped = true case1 } diff --git a/tests/explicit-nulls/neg/flow-match.scala b/tests/explicit-nulls/neg/flow-match.scala new file mode 100644 index 000000000000..e385758261cd --- /dev/null +++ b/tests/explicit-nulls/neg/flow-match.scala @@ -0,0 +1,15 @@ +// Test flow-typing when NotNullInfos are from cases + +object MatchTest { + def f6(s: String | Null): String = s match { + case s2 => s2 // error + case null => "other" // error + case s3 => s3 + } + + def f7(s: String | Null): String = s match { + case null => "other" + case null => "other" // error + case s3 => s3 + } +} diff --git a/tests/explicit-nulls/pos/flow-match.scala b/tests/explicit-nulls/pos/flow-match.scala index 260068b3ac3f..57e2c12b3c68 100644 --- a/tests/explicit-nulls/pos/flow-match.scala +++ b/tests/explicit-nulls/pos/flow-match.scala @@ -12,4 +12,47 @@ object MatchTest { // after the null case, s becomes non-nullable case _ => s } + + def f(s: String | Null): String = s match { + case null => "other" + case s2 => s2 + case s3 => s3 + } + + class Foo + + def f2(s: String | Null): String = s match { + case n @ null => "other" + case s2 => s2 + case s3 => s3 + } + + def f3(s: String | Null): String = s match { + case null | "foo" => "other" + case s2 => s2 + case s3 => s3 + } + + def f4(s: String | Null): String = s match { + case _ => "other" + case s2 => s2 + case s3 => s3 + } + + def f5(s: String | Null): String = s match { + case x => "other" + case s2 => s2 + case s3 => s3 + } + + def f6(s: String | Null): String = s match { + case s3: String => s3 + case null => "other" + case s4 => s4 + } + + def f7(s: String | Null): String = s match { + case s2 => s2.nn + case s3 => s3 + } }