Skip to content

Commit 2314852

Browse files
committed
Improve inference of union of generic types when one of the types is the generic type of the other
1 parent b1d5b92 commit 2314852

File tree

2 files changed

+84
-12
lines changed

2 files changed

+84
-12
lines changed

mypy/constraints.py

+58-11
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,11 @@ def infer_constraints_for_callable(
271271

272272

273273
def infer_constraints(
274-
template: Type, actual: Type, direction: int, skip_neg_op: bool = False
274+
template: Type,
275+
actual: Type,
276+
direction: int,
277+
skip_neg_op: bool = False,
278+
can_have_union_overlaping: bool = True,
275279
) -> list[Constraint]:
276280
"""Infer type constraints.
277281
@@ -311,11 +315,15 @@ def infer_constraints(
311315
res = _infer_constraints(template, actual, direction, skip_neg_op)
312316
type_state.inferring.pop()
313317
return res
314-
return _infer_constraints(template, actual, direction, skip_neg_op)
318+
return _infer_constraints(template, actual, direction, skip_neg_op, can_have_union_overlaping)
315319

316320

317321
def _infer_constraints(
318-
template: Type, actual: Type, direction: int, skip_neg_op: bool
322+
template: Type,
323+
actual: Type,
324+
direction: int,
325+
skip_neg_op: bool,
326+
can_have_union_overlaping: bool = True,
319327
) -> list[Constraint]:
320328
orig_template = template
321329
template = get_proper_type(template)
@@ -368,8 +376,41 @@ def _infer_constraints(
368376
return res
369377
if direction == SUPERTYPE_OF and isinstance(actual, UnionType):
370378
res = []
379+
380+
def _can_have_overlaping(_item: Type, _actual: UnionType) -> bool:
381+
# There is a special overlaping case, where we have a Union of where two types
382+
# are the same, but one of them contains the other.
383+
# For example, we have Union[Sequence[T], Sequence[Sequence[T]]]
384+
# In this case, only the second one can have overlaping because it contains the other.
385+
# So, in case of list[list[int]], second one would be chosen.
386+
if isinstance(p_item := get_proper_type(_item), Instance) and p_item.args:
387+
other_items = [o_item for o_item in _actual.items if o_item is not a_item]
388+
389+
if len(other_items) == 1 and other_items[0] in p_item.args:
390+
return True
391+
392+
if len(other_items) > 1:
393+
union_args = [
394+
p_arg
395+
for arg in p_item.args
396+
if isinstance(p_arg := get_proper_type(arg), UnionType)
397+
]
398+
399+
for union_arg in union_args:
400+
if all(o_item in union_arg.items for o_item in other_items):
401+
return True
402+
403+
return False
404+
371405
for a_item in actual.items:
372-
res.extend(infer_constraints(orig_template, a_item, direction))
406+
res.extend(
407+
infer_constraints(
408+
orig_template,
409+
a_item,
410+
direction,
411+
can_have_union_overlaping=_can_have_overlaping(a_item, actual),
412+
)
413+
)
373414
return res
374415

375416
# Now the potential subtype is known not to be a Union or a type
@@ -391,22 +432,28 @@ def _infer_constraints(
391432
# type variables indeterminate. This helps with some special
392433
# cases, though this isn't very principled.
393434

394-
def _is_item_being_overlaped_by_other(item: Type) -> bool:
395-
# It returns true if the item is an argument of other item
435+
def _is_item_overlaping_actual_type(_item: Type) -> bool:
436+
# Overlaping occurs when we have a Union where two types are
437+
# compatible and the more generic one is chosen.
438+
# For example, in Union[T, Sequence[T]], we have to choose
439+
# Sequence[T] if actual type is list[int].
440+
# This returns true if the item is an argument of other item
396441
# that is subtype of the actual type
397442
return any(
398-
isinstance(p_type := get_proper_type(item_to_compare), Instance)
399-
and mypy.subtypes.is_subtype(actual, erase_typevars(p_type))
400-
and item in p_type.args
443+
isinstance(p_item_to_compare := get_proper_type(item_to_compare), Instance)
444+
and mypy.subtypes.is_subtype(actual, erase_typevars(p_item_to_compare))
445+
and _item in p_item_to_compare.args
401446
for item_to_compare in template.items
402-
if item is not item_to_compare
447+
if _item is not item_to_compare
403448
)
404449

405450
result = any_constraints(
406451
[
407452
infer_constraints_if_possible(t_item, actual, direction)
408453
for t_item in [
409-
item for item in template.items if not _is_item_being_overlaped_by_other(item)
454+
item
455+
for item in template.items
456+
if not (can_have_union_overlaping and _is_item_overlaping_actual_type(item))
410457
]
411458
],
412459
eager=False,

test-data/unit/check-inference.test

+26-1
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,7 @@ i(b, a, b)
885885
i(a, b, b) # E: Argument 1 to "i" has incompatible type "List[int]"; expected "List[str]"
886886
[builtins fixtures/list.pyi]
887887

888-
[case testUnionInferenceOfGenericClassAndItsGenericType]
888+
[case testInferenceOfUnionOfGenericClassAndItsGenericType]
889889
from typing import Generic, TypeVar, Union
890890

891891
T = TypeVar('T')
@@ -907,6 +907,31 @@ reveal_type(result_2) # N: Revealed type is "__main__.GenericClass[builtins.str]
907907

908908
[builtins fixtures/isinstance.pyi]
909909

910+
[case testInferenceOfUnionOfSequenceOfAnyAndSequenceOfSequence]
911+
from typing import Sequence, Iterable, TypeVar, Union
912+
913+
T = TypeVar("T")
914+
S = TypeVar("S")
915+
916+
def sub_method(value: Union[S, Iterable[S]]) -> Iterable[S]:
917+
pass
918+
919+
def method(value: Union[Sequence[T], Sequence[Sequence[T]]]) -> None:
920+
reveal_type(sub_method(value)) # N: Revealed type is "typing.Iterable[typing.Sequence[T`-1]]"
921+
922+
[case testInferenceOfUnionOfUnionWithSequenceAndSequenceOfThatUnion]
923+
from typing import Sequence, Iterable, TypeVar, Union
924+
925+
T = Union[str, Sequence[int]]
926+
S = TypeVar("S", bound=T)
927+
928+
def sub_method(value: Union[S, Iterable[S]]) -> Iterable[S]:
929+
pass
930+
931+
def method(value: Union[T, Sequence[T]]) -> None:
932+
reveal_type(sub_method(value)) # N: Revealed type is "typing.Iterable[Union[builtins.str, typing.Sequence[builtins.int]]]"
933+
934+
910935
[case testCallableListJoinInference]
911936
from typing import Any, Callable
912937

0 commit comments

Comments
 (0)