Skip to content

Commit 8019010

Browse files
loic-simonLoïc Simon
and
Loïc Simon
authored
Narrow individual items when matching a tuple to a sequence pattern (#16905)
Fixes #12364 When matching a tuple to a sequence pattern, this change narrows the type of tuple items inside the matched case: ```py def test(a: bool, b: bool) -> None: match a, b: case True, True: reveal_type(a) # before: "builtins.bool", after: "Literal[True]" ``` This also works with nested tuples, recursively: ```py def test(a: bool, b: bool, c: bool) -> None: match a, (b, c): case _, [True, False]: reveal_type(c) # before: "builtins.bool", after: "Literal[False]" ``` This only partially fixes issue #12364; see [my comment there](#12364 (comment)) for more context. --- This is my first contribution to mypy, so I may miss some context or conventions; I'm eager for any feedback! --------- Co-authored-by: Loïc Simon <[email protected]>
1 parent ec44015 commit 8019010

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

mypy/checker.py

+17
Original file line numberDiff line numberDiff line change
@@ -5119,6 +5119,9 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
51195119
)
51205120
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
51215121
self.push_type_map(pattern_map)
5122+
if pattern_map:
5123+
for expr, typ in pattern_map.items():
5124+
self.push_type_map(self._get_recursive_sub_patterns_map(expr, typ))
51225125
self.push_type_map(pattern_type.captures)
51235126
if g is not None:
51245127
with self.binder.frame_context(can_skip=False, fall_through=3):
@@ -5156,6 +5159,20 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
51565159
with self.binder.frame_context(can_skip=False, fall_through=2):
51575160
pass
51585161

5162+
def _get_recursive_sub_patterns_map(
5163+
self, expr: Expression, typ: Type
5164+
) -> dict[Expression, Type]:
5165+
sub_patterns_map: dict[Expression, Type] = {}
5166+
typ_ = get_proper_type(typ)
5167+
if isinstance(expr, TupleExpr) and isinstance(typ_, TupleType):
5168+
# When matching a tuple expression with a sequence pattern, narrow individual tuple items
5169+
assert len(expr.items) == len(typ_.items)
5170+
for item_expr, item_typ in zip(expr.items, typ_.items):
5171+
sub_patterns_map[item_expr] = item_typ
5172+
sub_patterns_map.update(self._get_recursive_sub_patterns_map(item_expr, item_typ))
5173+
5174+
return sub_patterns_map
5175+
51595176
def infer_variable_types_from_type_maps(self, type_maps: list[TypeMap]) -> dict[Var, Type]:
51605177
all_captures: dict[Var, list[tuple[NameExpr, Type]]] = defaultdict(list)
51615178
for tm in type_maps:

test-data/unit/check-python310.test

+66
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,72 @@ match m:
341341
reveal_type(m) # N: Revealed type is "builtins.list[builtins.list[builtins.str]]"
342342
[builtins fixtures/list.pyi]
343343

344+
[case testMatchSequencePatternNarrowSubjectItems]
345+
m: int
346+
n: str
347+
o: bool
348+
349+
match m, n, o:
350+
case [3, "foo", True]:
351+
reveal_type(m) # N: Revealed type is "Literal[3]"
352+
reveal_type(n) # N: Revealed type is "Literal['foo']"
353+
reveal_type(o) # N: Revealed type is "Literal[True]"
354+
case [a, b, c]:
355+
reveal_type(m) # N: Revealed type is "builtins.int"
356+
reveal_type(n) # N: Revealed type is "builtins.str"
357+
reveal_type(o) # N: Revealed type is "builtins.bool"
358+
359+
reveal_type(m) # N: Revealed type is "builtins.int"
360+
reveal_type(n) # N: Revealed type is "builtins.str"
361+
reveal_type(o) # N: Revealed type is "builtins.bool"
362+
[builtins fixtures/tuple.pyi]
363+
364+
[case testMatchSequencePatternNarrowSubjectItemsRecursive]
365+
m: int
366+
n: int
367+
o: int
368+
p: int
369+
q: int
370+
r: int
371+
372+
match m, (n, o), (p, (q, r)):
373+
case [0, [1, 2], [3, [4, 5]]]:
374+
reveal_type(m) # N: Revealed type is "Literal[0]"
375+
reveal_type(n) # N: Revealed type is "Literal[1]"
376+
reveal_type(o) # N: Revealed type is "Literal[2]"
377+
reveal_type(p) # N: Revealed type is "Literal[3]"
378+
reveal_type(q) # N: Revealed type is "Literal[4]"
379+
reveal_type(r) # N: Revealed type is "Literal[5]"
380+
[builtins fixtures/tuple.pyi]
381+
382+
[case testMatchSequencePatternSequencesLengthMismatchNoNarrowing]
383+
m: int
384+
n: str
385+
o: bool
386+
387+
match m, n, o:
388+
case [3, "foo"]:
389+
pass
390+
case [3, "foo", True, True]:
391+
pass
392+
[builtins fixtures/tuple.pyi]
393+
394+
[case testMatchSequencePatternSequencesLengthMismatchNoNarrowingRecursive]
395+
m: int
396+
n: int
397+
o: int
398+
399+
match m, (n, o):
400+
case [0]:
401+
pass
402+
case [0, 1, [2]]:
403+
pass
404+
case [0, [1]]:
405+
pass
406+
case [0, [1, 2, 3]]:
407+
pass
408+
[builtins fixtures/tuple.pyi]
409+
344410
-- Mapping Pattern --
345411

346412
[case testMatchMappingPatternCaptures]

0 commit comments

Comments
 (0)