From b3ec408026ac9581e4058e93d2e1c3a32b84ccfe Mon Sep 17 00:00:00 2001 From: STerliakov Date: Sat, 11 Jan 2025 22:15:29 +0100 Subject: [PATCH 1/4] Create fake named expressions for `match` subject in more cases --- mypy/checker.py | 59 +++++++++++++++----- test-data/unit/check-python310.test | 85 ++++++++++++++++++++++++++++- 2 files changed, 130 insertions(+), 14 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 80de4254766b..0560edcca00c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -65,10 +65,12 @@ CallExpr, ClassDef, ComparisonExpr, + ComplexExpr, Context, ContinueStmt, Decorator, DelStmt, + DictExpr, EllipsisExpr, Expression, ExpressionStmt, @@ -100,6 +102,7 @@ RaiseStmt, RefExpr, ReturnStmt, + SetExpr, StarExpr, Statement, StrExpr, @@ -350,6 +353,9 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface): # functions such as open(), etc. plugin: Plugin + # A helper state to produce unique temporary names on demand. + _unique_id: int + def __init__( self, errors: Errors, @@ -413,6 +419,7 @@ def __init__( self, self.msg, self.plugin, per_line_checking_time_ns ) self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options) + self._unique_id = 0 @property def type_context(self) -> list[Type | None]: @@ -5273,19 +5280,7 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: return def visit_match_stmt(self, s: MatchStmt) -> None: - named_subject: Expression - if isinstance(s.subject, CallExpr): - # Create a dummy subject expression to handle cases where a match statement's subject - # is not a literal value. This lets us correctly narrow types and check exhaustivity - # This is hack! - id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else "" - name = "dummy-match-" + id - v = Var(name) - named_subject = NameExpr(name) - named_subject.node = v - else: - named_subject = s.subject - + named_subject = self._make_named_statement_for_match(s.subject) with self.binder.frame_context(can_skip=False, fall_through=0): subject_type = get_proper_type(self.expr_checker.accept(s.subject)) @@ -5362,6 +5357,38 @@ def visit_match_stmt(self, s: MatchStmt) -> None: with self.binder.frame_context(can_skip=False, fall_through=2): pass + def _make_named_statement_for_match(self, subject: Expression) -> Expression: + """Construct a fake NameExpr for inference if a match clause is complex.""" + expressions_to_preserve = ( + # Already named - we should infer type of it as given + NameExpr, + AssignmentExpr, + # Collection literals defined inline - we want to infer types of variables + # included there, not exprs as a whole + ListExpr, + DictExpr, + TupleExpr, + SetExpr, + # Primitive literals - their type is known, no need to name them + IntExpr, + StrExpr, + BytesExpr, + FloatExpr, + ComplexExpr, + EllipsisExpr, + ) + if isinstance(subject, expressions_to_preserve): + return subject + else: + # Create a dummy subject expression to handle cases where a match statement's subject + # is not a literal value. This lets us correctly narrow types and check exhaustivity + # This is hack! + name = self.new_unique_dummy_name("match") + v = Var(name) + named_subject = NameExpr(name) + named_subject.node = v + return named_subject + def _get_recursive_sub_patterns_map( self, expr: Expression, typ: Type ) -> dict[Expression, Type]: @@ -7715,6 +7742,12 @@ def warn_deprecated_overload_item( if candidate == target: self.warn_deprecated(item.func, context) + def new_unique_dummy_name(self, namespace: str) -> str: + """Generate a name that is guaranteed to be unique for this TypeChecker instance.""" + name = f"dummy-{namespace}-{self._unique_id}" + self._unique_id += 1 + return name + class CollectArgTypeVarTypes(TypeTraverserVisitor): """Collects the non-nested argument types in a set.""" diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 616846789c98..20d4fb057de2 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1239,7 +1239,7 @@ def main() -> None: case a: reveal_type(a) # N: Revealed type is "builtins.int" -[case testMatchCapturePatternFromAsyncFunctionReturningUnion-xfail] +[case testMatchCapturePatternFromAsyncFunctionReturningUnion] async def func1(arg: bool) -> str | int: ... async def func2(arg: bool) -> bytes | int: ... @@ -2439,3 +2439,86 @@ def foo(x: T) -> T: return out [builtins fixtures/isinstance.pyi] + +[case testMatchFunctionCall] +# flags: --warn-unreachable + +def fn() -> int | str: ... + +match fn(): + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[case testMatchAttribute] +# flags: --warn-unreachable + +class A: + foo: int | str + +match A().foo: + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[case testMatchOperations] +# flags: --warn-unreachable + +x: int +match -x: + case -1 as s: + reveal_type(s) # N: Revealed type is "Literal[-1]" + case int(s): + reveal_type(s) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +match 1 + 2: + case 3 as s: + reveal_type(s) # N: Revealed type is "Literal[3]" + case int(s): + reveal_type(s) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +match 1 > 2: + case True as s: + reveal_type(s) # N: Revealed type is "Literal[True]" + case False as s: + reveal_type(s) # N: Revealed type is "Literal[False]" + case other: + other # E: Statement is unreachable +[builtins fixtures/ops.pyi] + +[case testMatchDictItem] +# flags: --warn-unreachable + +m: dict[str, int | str] +k: str + +match m[k]: + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[builtins fixtures/dict.pyi] + +[case testMatchLiteralValuePathological] +# flags: --warn-unreachable + +match 0: + case 0 as i: + reveal_type(i) # N: Revealed type is "Literal[0]?" + case int(i): + i # E: Statement is unreachable + case other: + other # E: Statement is unreachable From 5e0b2594c07ed54661db7f034e3410d6d0c4a006 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Sat, 5 Apr 2025 00:40:51 +0200 Subject: [PATCH 2/4] Add the original subject to typemap for inference --- mypy/checker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mypy/checker.py b/mypy/checker.py index 13985b9cfdaa..0f48ef3be4c6 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5565,6 +5565,10 @@ def visit_match_stmt(self, s: MatchStmt) -> None: pattern_map, else_map = conditional_types_to_typemaps( named_subject, pattern_type.type, pattern_type.rest_type ) + if pattern_map and named_subject in pattern_map: + pattern_map[s.subject] = pattern_map[named_subject] + if else_map and named_subject in else_map: + else_map[s.subject] = else_map[named_subject] pattern_map = self.propagate_up_typemap_info(pattern_map) else_map = self.propagate_up_typemap_info(else_map) self.remove_capture_conflicts(pattern_type.captures, inferred_types) From 85d16bd8d811f8f03b3943daef36f4c302cc1b49 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Sat, 5 Apr 2025 00:48:07 +0200 Subject: [PATCH 3/4] Add comment --- mypy/checker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mypy/checker.py b/mypy/checker.py index 7223c806739c..3debd929aab0 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5452,6 +5452,8 @@ def visit_match_stmt(self, s: MatchStmt) -> None: pattern_map, else_map = conditional_types_to_typemaps( named_subject, pattern_type.type, pattern_type.rest_type ) + # Maybe the subject type can be inferred from constraints on + # its attribute/item? if pattern_map and named_subject in pattern_map: pattern_map[s.subject] = pattern_map[named_subject] if else_map and named_subject in else_map: From ded98cb649bae1140e3298ea8f142bbfe2d07bc9 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Sat, 5 Apr 2025 04:05:07 +0200 Subject: [PATCH 4/4] And we can keep inline collection literals too now since we infer both --- mypy/checker.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 3debd929aab0..ccbed78d49ff 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -72,7 +72,6 @@ ContinueStmt, Decorator, DelStmt, - DictExpr, EllipsisExpr, Expression, ExpressionStmt, @@ -106,7 +105,6 @@ RaiseStmt, RefExpr, ReturnStmt, - SetExpr, StarExpr, Statement, StrExpr, @@ -5512,12 +5510,6 @@ def _make_named_statement_for_match(self, s: MatchStmt) -> Expression: # Already named - we should infer type of it as given NameExpr, AssignmentExpr, - # Collection literals defined inline - we want to infer types of variables - # included there, not exprs as a whole - ListExpr, - DictExpr, - TupleExpr, - SetExpr, # Primitive literals - their type is known, no need to name them IntExpr, StrExpr,