Skip to content

Fix union inference of generic class and its generic type #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 66 additions & 5 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,11 @@ def infer_constraints_for_callable(


def infer_constraints(
template: Type, actual: Type, direction: int, skip_neg_op: bool = False
template: Type,
actual: Type,
direction: int,
skip_neg_op: bool = False,
can_have_union_overlapping: bool = True,
) -> list[Constraint]:
"""Infer type constraints.

Expand Down Expand Up @@ -316,11 +320,15 @@ def infer_constraints(
res = _infer_constraints(template, actual, direction, skip_neg_op)
type_state.inferring.pop()
return res
return _infer_constraints(template, actual, direction, skip_neg_op)
return _infer_constraints(template, actual, direction, skip_neg_op, can_have_union_overlapping)


def _infer_constraints(
template: Type, actual: Type, direction: int, skip_neg_op: bool
template: Type,
actual: Type,
direction: int,
skip_neg_op: bool,
can_have_union_overlapping: bool = True,
) -> list[Constraint]:
orig_template = template
template = get_proper_type(template)
Expand Down Expand Up @@ -383,13 +391,46 @@ def _infer_constraints(
return res
if direction == SUPERTYPE_OF and isinstance(actual, UnionType):
res = []

def _can_have_overlapping(_item: Type, _actual: UnionType) -> bool:
# There is a special overlapping case, where we have a Union of where two types
# are the same, but one of them contains the other.
# For example, we have Union[Sequence[T], Sequence[Sequence[T]]]
# In this case, only the second one can have overlapping because it contains the other.
# So, in case of list[list[int]], second one would be chosen.
if isinstance(p_item := get_proper_type(_item), Instance) and p_item.args:
other_items = [o_item for o_item in _actual.items if o_item is not a_item]

if len(other_items) == 1 and other_items[0] in p_item.args:
return True

Comment on lines +395 to +406
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

NameError: undefined variable a_item in _can_have_overlapping

a_item is not defined in this scope – the comprehension intends to compare each
o_item with the _item currently being inspected.

-                other_items = [o_item for o_item in _actual.items if o_item is not a_item]
+                other_items = [o_item for o_item in _actual.items if o_item is not _item]

Without this fix the helper will raise at runtime and break constraint inference
whenever SUPERTYPE_OF union handling is triggered.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _can_have_overlapping(_item: Type, _actual: UnionType) -> bool:
# There is a special overlapping case, where we have a Union of where two types
# are the same, but one of them contains the other.
# For example, we have Union[Sequence[T], Sequence[Sequence[T]]]
# In this case, only the second one can have overlapping because it contains the other.
# So, in case of list[list[int]], second one would be chosen.
if isinstance(p_item := get_proper_type(_item), Instance) and p_item.args:
other_items = [o_item for o_item in _actual.items if o_item is not a_item]
if len(other_items) == 1 and other_items[0] in p_item.args:
return True
def _can_have_overlapping(_item: Type, _actual: UnionType) -> bool:
# There is a special overlapping case, where we have a Union of where two types
# are the same, but one of them contains the other.
# For example, we have Union[Sequence[T], Sequence[Sequence[T]]]
# In this case, only the second one can have overlapping because it contains the other.
# So, in case of list[list[int]], second one would be chosen.
if isinstance(p_item := get_proper_type(_item), Instance) and p_item.args:
- other_items = [o_item for o_item in _actual.items if o_item is not a_item]
+ other_items = [o_item for o_item in _actual.items if o_item is not _item]
if len(other_items) == 1 and other_items[0] in p_item.args:
return True

if len(other_items) > 1:
union_args = [
p_arg
for arg in p_item.args
if isinstance(p_arg := get_proper_type(arg), UnionType)
]

for union_arg in union_args:
if all(o_item in union_arg.items for o_item in other_items):
return True

return False

for a_item in actual.items:
# `orig_template` has to be preserved intact in case it's recursive.
# If we unwrapped ``type[...]`` previously, wrap the item back again,
# as ``type[...]`` can't be removed from `orig_template`.
if type_type_unwrapped:
a_item = TypeType.make_normalized(a_item)
res.extend(infer_constraints(orig_template, a_item, direction))
res.extend(
infer_constraints(
orig_template,
a_item,
direction,
can_have_union_overlapping=_can_have_overlapping(a_item, actual),
)
)
return res

# Now the potential subtype is known not to be a Union or a type
Expand All @@ -410,10 +451,30 @@ def _infer_constraints(
# When the template is a union, we are okay with leaving some
# type variables indeterminate. This helps with some special
# cases, though this isn't very principled.

def _is_item_overlapping_actual_type(_item: Type) -> bool:
# Overlapping occurs when we have a Union where two types are
# compatible and the more generic one is chosen.
# For example, in Union[T, Sequence[T]], we have to choose
# Sequence[T] if actual type is list[int].
# This returns true if the item is an argument of other item
# that is subtype of the actual type
return any(
isinstance(p_item_to_compare := get_proper_type(item_to_compare), Instance)
and mypy.subtypes.is_subtype(actual, erase_typevars(p_item_to_compare))
and _item in p_item_to_compare.args
for item_to_compare in template.items
if _item is not item_to_compare
)

result = any_constraints(
[
infer_constraints_if_possible(t_item, actual, direction)
for t_item in template.items
for t_item in [
item
for item in template.items
if not (can_have_union_overlapping and _is_item_overlapping_actual_type(item))
]
],
eager=False,
)
Expand Down
55 changes: 48 additions & 7 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -873,13 +873,7 @@ def g(x: Union[T, List[T]]) -> List[T]: pass
def h(x: List[str]) -> None: pass
g('a')() # E: "List[str]" not callable

# The next line is a case where there are multiple ways to satisfy a constraint
# involving a Union. Either T = List[str] or T = str would turn out to be valid,
# but mypy doesn't know how to branch on these two options (and potentially have
# to backtrack later) and defaults to T = Never. The result is an
# awkward error message. Either a better error message, or simply accepting the
# call, would be preferable here.
g(['a']) # E: Argument 1 to "g" has incompatible type "List[str]"; expected "List[Never]"
g(['a'])

h(g(['a']))

Expand All @@ -891,6 +885,53 @@ i(b, a, b)
i(a, b, b) # E: Argument 1 to "i" has incompatible type "List[int]"; expected "List[str]"
[builtins fixtures/list.pyi]

[case testInferenceOfUnionOfGenericClassAndItsGenericType]
from typing import Generic, TypeVar, Union

T = TypeVar('T')

class GenericClass(Generic[T]):
def __init__(self, value: T) -> None:
self.value = value

def method_with_union(arg: Union[GenericClass[T], T]) -> GenericClass[T]:
if not isinstance(arg, GenericClass):
arg = GenericClass(arg)
return arg

result_1 = method_with_union(GenericClass("test"))
reveal_type(result_1) # N: Revealed type is "__main__.GenericClass[builtins.str]"

result_2 = method_with_union("test")
reveal_type(result_2) # N: Revealed type is "__main__.GenericClass[builtins.str]"

[builtins fixtures/isinstance.pyi]

[case testInferenceOfUnionOfSequenceOfAnyAndSequenceOfSequence]
from typing import Sequence, Iterable, TypeVar, Union

T = TypeVar("T")
S = TypeVar("S")

def sub_method(value: Union[S, Iterable[S]]) -> Iterable[S]:
pass

def method(value: Union[Sequence[T], Sequence[Sequence[T]]]) -> None:
reveal_type(sub_method(value)) # N: Revealed type is "typing.Iterable[typing.Sequence[T`-1]]"

[case testInferenceOfUnionOfUnionWithSequenceAndSequenceOfThatUnion]
from typing import Sequence, Iterable, TypeVar, Union

T = Union[str, Sequence[int]]
S = TypeVar("S", bound=T)

def sub_method(value: Union[S, Iterable[S]]) -> Iterable[S]:
pass

def method(value: Union[T, Sequence[T]]) -> None:
reveal_type(sub_method(value)) # N: Revealed type is "typing.Iterable[Union[builtins.str, typing.Sequence[builtins.int]]]"


[case testCallableListJoinInference]
from typing import Any, Callable

Expand Down