diff --git a/mypy/constraints.py b/mypy/constraints.py index d88b722aa1ce..531de8b198e1 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -276,7 +276,11 @@ def infer_constraints_for_callable( def infer_constraints( - template: Type, actual: Type, direction: int, skip_neg_op: bool = False + template: Type, + actual: Type, + direction: int, + skip_neg_op: bool = False, + can_have_union_overlapping: bool = True, ) -> list[Constraint]: """Infer type constraints. @@ -316,11 +320,15 @@ def infer_constraints( res = _infer_constraints(template, actual, direction, skip_neg_op) type_state.inferring.pop() return res - return _infer_constraints(template, actual, direction, skip_neg_op) + return _infer_constraints(template, actual, direction, skip_neg_op, can_have_union_overlapping) def _infer_constraints( - template: Type, actual: Type, direction: int, skip_neg_op: bool + template: Type, + actual: Type, + direction: int, + skip_neg_op: bool, + can_have_union_overlapping: bool = True, ) -> list[Constraint]: orig_template = template template = get_proper_type(template) @@ -383,13 +391,46 @@ def _infer_constraints( return res if direction == SUPERTYPE_OF and isinstance(actual, UnionType): res = [] + + def _can_have_overlapping(_item: Type, _actual: UnionType) -> bool: + # There is a special overlapping case, where we have a Union of where two types + # are the same, but one of them contains the other. + # For example, we have Union[Sequence[T], Sequence[Sequence[T]]] + # In this case, only the second one can have overlapping because it contains the other. + # So, in case of list[list[int]], second one would be chosen. + if isinstance(p_item := get_proper_type(_item), Instance) and p_item.args: + other_items = [o_item for o_item in _actual.items if o_item is not a_item] + + if len(other_items) == 1 and other_items[0] in p_item.args: + return True + + if len(other_items) > 1: + union_args = [ + p_arg + for arg in p_item.args + if isinstance(p_arg := get_proper_type(arg), UnionType) + ] + + for union_arg in union_args: + if all(o_item in union_arg.items for o_item in other_items): + return True + + return False + for a_item in actual.items: # `orig_template` has to be preserved intact in case it's recursive. # If we unwrapped ``type[...]`` previously, wrap the item back again, # as ``type[...]`` can't be removed from `orig_template`. if type_type_unwrapped: a_item = TypeType.make_normalized(a_item) - res.extend(infer_constraints(orig_template, a_item, direction)) + res.extend( + infer_constraints( + orig_template, + a_item, + direction, + can_have_union_overlapping=_can_have_overlapping(a_item, actual), + ) + ) return res # Now the potential subtype is known not to be a Union or a type @@ -410,10 +451,30 @@ def _infer_constraints( # When the template is a union, we are okay with leaving some # type variables indeterminate. This helps with some special # cases, though this isn't very principled. + + def _is_item_overlapping_actual_type(_item: Type) -> bool: + # Overlapping occurs when we have a Union where two types are + # compatible and the more generic one is chosen. + # For example, in Union[T, Sequence[T]], we have to choose + # Sequence[T] if actual type is list[int]. + # This returns true if the item is an argument of other item + # that is subtype of the actual type + return any( + isinstance(p_item_to_compare := get_proper_type(item_to_compare), Instance) + and mypy.subtypes.is_subtype(actual, erase_typevars(p_item_to_compare)) + and _item in p_item_to_compare.args + for item_to_compare in template.items + if _item is not item_to_compare + ) + result = any_constraints( [ infer_constraints_if_possible(t_item, actual, direction) - for t_item in template.items + for t_item in [ + item + for item in template.items + if not (can_have_union_overlapping and _is_item_overlapping_actual_type(item)) + ] ], eager=False, ) diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index cb0b11bf013c..0dc4d8ec5840 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -873,13 +873,7 @@ def g(x: Union[T, List[T]]) -> List[T]: pass def h(x: List[str]) -> None: pass g('a')() # E: "List[str]" not callable -# The next line is a case where there are multiple ways to satisfy a constraint -# involving a Union. Either T = List[str] or T = str would turn out to be valid, -# but mypy doesn't know how to branch on these two options (and potentially have -# to backtrack later) and defaults to T = Never. The result is an -# awkward error message. Either a better error message, or simply accepting the -# call, would be preferable here. -g(['a']) # E: Argument 1 to "g" has incompatible type "List[str]"; expected "List[Never]" +g(['a']) h(g(['a'])) @@ -891,6 +885,53 @@ i(b, a, b) i(a, b, b) # E: Argument 1 to "i" has incompatible type "List[int]"; expected "List[str]" [builtins fixtures/list.pyi] +[case testInferenceOfUnionOfGenericClassAndItsGenericType] +from typing import Generic, TypeVar, Union + +T = TypeVar('T') + +class GenericClass(Generic[T]): + def __init__(self, value: T) -> None: + self.value = value + +def method_with_union(arg: Union[GenericClass[T], T]) -> GenericClass[T]: + if not isinstance(arg, GenericClass): + arg = GenericClass(arg) + return arg + +result_1 = method_with_union(GenericClass("test")) +reveal_type(result_1) # N: Revealed type is "__main__.GenericClass[builtins.str]" + +result_2 = method_with_union("test") +reveal_type(result_2) # N: Revealed type is "__main__.GenericClass[builtins.str]" + +[builtins fixtures/isinstance.pyi] + +[case testInferenceOfUnionOfSequenceOfAnyAndSequenceOfSequence] +from typing import Sequence, Iterable, TypeVar, Union + +T = TypeVar("T") +S = TypeVar("S") + +def sub_method(value: Union[S, Iterable[S]]) -> Iterable[S]: + pass + +def method(value: Union[Sequence[T], Sequence[Sequence[T]]]) -> None: + reveal_type(sub_method(value)) # N: Revealed type is "typing.Iterable[typing.Sequence[T`-1]]" + +[case testInferenceOfUnionOfUnionWithSequenceAndSequenceOfThatUnion] +from typing import Sequence, Iterable, TypeVar, Union + +T = Union[str, Sequence[int]] +S = TypeVar("S", bound=T) + +def sub_method(value: Union[S, Iterable[S]]) -> Iterable[S]: + pass + +def method(value: Union[T, Sequence[T]]) -> None: + reveal_type(sub_method(value)) # N: Revealed type is "typing.Iterable[Union[builtins.str, typing.Sequence[builtins.int]]]" + + [case testCallableListJoinInference] from typing import Any, Callable