Skip to content

Commit 5a19bfb

Browse files
sobolevnJukkaL
authored andcommitted
Properly check *CustomType and **CustomType arguments (#11151)
1 parent f08d72b commit 5a19bfb

15 files changed

+336
-62
lines changed

mypy/argmap.py

+31-14
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
"""Utilities for mapping between actual and formal arguments (and their types)."""
22

3-
from typing import List, Optional, Sequence, Callable, Set
3+
from typing import TYPE_CHECKING, List, Optional, Sequence, Callable, Set
44

5+
from mypy.maptype import map_instance_to_supertype
56
from mypy.types import (
67
Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType, get_proper_type
78
)
89
from mypy import nodes
910

11+
if TYPE_CHECKING:
12+
from mypy.infer import ArgumentInferContext
13+
1014

1115
def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind],
1216
actual_names: Optional[Sequence[Optional[str]]],
@@ -140,11 +144,13 @@ def f(x: int, *args: str) -> None: ...
140144
needs a separate instance since instances have per-call state.
141145
"""
142146

143-
def __init__(self) -> None:
147+
def __init__(self, context: 'ArgumentInferContext') -> None:
144148
# Next tuple *args index to use.
145149
self.tuple_index = 0
146150
# Keyword arguments in TypedDict **kwargs used.
147151
self.kwargs_used: Set[str] = set()
152+
# Type context for `*` and `**` arg kinds.
153+
self.context = context
148154

149155
def expand_actual_type(self,
150156
actual_type: Type,
@@ -162,16 +168,21 @@ def expand_actual_type(self,
162168
This is supposed to be called for each formal, in order. Call multiple times per
163169
formal if multiple actuals map to a formal.
164170
"""
171+
from mypy.subtypes import is_subtype
172+
165173
actual_type = get_proper_type(actual_type)
166174
if actual_kind == nodes.ARG_STAR:
167-
if isinstance(actual_type, Instance):
168-
if actual_type.type.fullname == 'builtins.list':
169-
# List *arg.
170-
return actual_type.args[0]
171-
elif actual_type.args:
172-
# TODO: Try to map type arguments to Iterable
173-
return actual_type.args[0]
175+
if isinstance(actual_type, Instance) and actual_type.args:
176+
if is_subtype(actual_type, self.context.iterable_type):
177+
return map_instance_to_supertype(
178+
actual_type,
179+
self.context.iterable_type.type,
180+
).args[0]
174181
else:
182+
# We cannot properly unpack anything other
183+
# than `Iterable` type with `*`.
184+
# Just return `Any`, other parts of code would raise
185+
# a different error for improper use.
175186
return AnyType(TypeOfAny.from_error)
176187
elif isinstance(actual_type, TupleType):
177188
# Get the next tuple item of a tuple *arg.
@@ -193,11 +204,17 @@ def expand_actual_type(self,
193204
formal_name = (set(actual_type.items.keys()) - self.kwargs_used).pop()
194205
self.kwargs_used.add(formal_name)
195206
return actual_type.items[formal_name]
196-
elif (isinstance(actual_type, Instance)
197-
and (actual_type.type.fullname == 'builtins.dict')):
198-
# Dict **arg.
199-
# TODO: Handle arbitrary Mapping
200-
return actual_type.args[1]
207+
elif (
208+
isinstance(actual_type, Instance) and
209+
len(actual_type.args) > 1 and
210+
is_subtype(actual_type, self.context.mapping_type)
211+
):
212+
# Only `Mapping` type can be unpacked with `**`.
213+
# Other types will produce an error somewhere else.
214+
return map_instance_to_supertype(
215+
actual_type,
216+
self.context.mapping_type.type,
217+
).args[1]
201218
else:
202219
return AnyType(TypeOfAny.from_error)
203220
else:

mypy/checkexpr.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@
4545
from mypy.maptype import map_instance_to_supertype
4646
from mypy.messages import MessageBuilder
4747
from mypy import message_registry
48-
from mypy.infer import infer_type_arguments, infer_function_type_arguments
48+
from mypy.infer import (
49+
ArgumentInferContext, infer_type_arguments, infer_function_type_arguments,
50+
)
4951
from mypy import join
5052
from mypy.meet import narrow_declared_type, is_overlapping_types
5153
from mypy.subtypes import is_subtype, is_proper_subtype, is_equivalent, non_method_protocol_members
@@ -1235,6 +1237,7 @@ def infer_function_type_arguments(self, callee_type: CallableType,
12351237

12361238
inferred_args = infer_function_type_arguments(
12371239
callee_type, pass1_args, arg_kinds, formal_to_actual,
1240+
context=self.argument_infer_context(),
12381241
strict=self.chk.in_checked_function())
12391242

12401243
if 2 in arg_pass_nums:
@@ -1296,10 +1299,18 @@ def infer_function_type_arguments_pass2(
12961299
callee_type, args, arg_kinds, formal_to_actual)
12971300

12981301
inferred_args = infer_function_type_arguments(
1299-
callee_type, arg_types, arg_kinds, formal_to_actual)
1302+
callee_type, arg_types, arg_kinds, formal_to_actual,
1303+
context=self.argument_infer_context(),
1304+
)
13001305

13011306
return callee_type, inferred_args
13021307

1308+
def argument_infer_context(self) -> ArgumentInferContext:
1309+
return ArgumentInferContext(
1310+
self.chk.named_type('typing.Mapping'),
1311+
self.chk.named_type('typing.Iterable'),
1312+
)
1313+
13031314
def get_arg_infer_passes(self, arg_types: List[Type],
13041315
formal_to_actual: List[List[int]],
13051316
num_actuals: int) -> List[int]:
@@ -1474,7 +1485,7 @@ def check_argument_types(self,
14741485
messages = messages or self.msg
14751486
check_arg = check_arg or self.check_arg
14761487
# Keep track of consumed tuple *arg items.
1477-
mapper = ArgTypeExpander()
1488+
mapper = ArgTypeExpander(self.argument_infer_context())
14781489
for i, actuals in enumerate(formal_to_actual):
14791490
for actual in actuals:
14801491
actual_type = arg_types[actual]

mypy/constraints.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Type inference constraints."""
22

3-
from typing import Iterable, List, Optional, Sequence
3+
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
44
from typing_extensions import Final
55

66
from mypy.types import (
@@ -18,6 +18,9 @@
1818
from mypy.argmap import ArgTypeExpander
1919
from mypy.typestate import TypeState
2020

21+
if TYPE_CHECKING:
22+
from mypy.infer import ArgumentInferContext
23+
2124
SUBTYPE_OF: Final = 0
2225
SUPERTYPE_OF: Final = 1
2326

@@ -45,14 +48,17 @@ def __repr__(self) -> str:
4548

4649

4750
def infer_constraints_for_callable(
48-
callee: CallableType, arg_types: Sequence[Optional[Type]], arg_kinds: List[ArgKind],
49-
formal_to_actual: List[List[int]]) -> List[Constraint]:
51+
callee: CallableType,
52+
arg_types: Sequence[Optional[Type]],
53+
arg_kinds: List[ArgKind],
54+
formal_to_actual: List[List[int]],
55+
context: 'ArgumentInferContext') -> List[Constraint]:
5056
"""Infer type variable constraints for a callable and actual arguments.
5157
5258
Return a list of constraints.
5359
"""
5460
constraints: List[Constraint] = []
55-
mapper = ArgTypeExpander()
61+
mapper = ArgTypeExpander(context)
5662

5763
for i, actuals in enumerate(formal_to_actual):
5864
for actual in actuals:

mypy/infer.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,34 @@
11
"""Utilities for type argument inference."""
22

3-
from typing import List, Optional, Sequence
3+
from typing import List, Optional, Sequence, NamedTuple
44

55
from mypy.constraints import (
66
infer_constraints, infer_constraints_for_callable, SUBTYPE_OF, SUPERTYPE_OF
77
)
8-
from mypy.types import Type, TypeVarId, CallableType
8+
from mypy.types import Type, TypeVarId, CallableType, Instance
99
from mypy.nodes import ArgKind
1010
from mypy.solve import solve_constraints
1111

1212

13+
class ArgumentInferContext(NamedTuple):
14+
"""Type argument inference context.
15+
16+
We need this because we pass around ``Mapping`` and ``Iterable`` types.
17+
These types are only known by ``TypeChecker`` itself.
18+
It is required for ``*`` and ``**`` argument inference.
19+
20+
https://github.com/python/mypy/issues/11144
21+
"""
22+
23+
mapping_type: Instance
24+
iterable_type: Instance
25+
26+
1327
def infer_function_type_arguments(callee_type: CallableType,
1428
arg_types: Sequence[Optional[Type]],
1529
arg_kinds: List[ArgKind],
1630
formal_to_actual: List[List[int]],
31+
context: ArgumentInferContext,
1732
strict: bool = True) -> List[Optional[Type]]:
1833
"""Infer the type arguments of a generic function.
1934
@@ -30,7 +45,7 @@ def infer_function_type_arguments(callee_type: CallableType,
3045
"""
3146
# Infer constraints.
3247
constraints = infer_constraints_for_callable(
33-
callee_type, arg_types, arg_kinds, formal_to_actual)
48+
callee_type, arg_types, arg_kinds, formal_to_actual, context)
3449

3550
# Solve constraints.
3651
type_vars = callee_type.type_var_ids()

test-data/unit/check-expressions.test

+4-31
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,7 @@ if str():
6363
a = 1.1
6464
class A:
6565
pass
66-
[file builtins.py]
67-
class object:
68-
def __init__(self): pass
69-
class type: pass
70-
class function: pass
71-
class float: pass
72-
class str: pass
66+
[builtins fixtures/dict.pyi]
7367

7468
[case testComplexLiteral]
7569
a = 0.0j
@@ -80,13 +74,7 @@ if str():
8074
a = 1.1j
8175
class A:
8276
pass
83-
[file builtins.py]
84-
class object:
85-
def __init__(self): pass
86-
class type: pass
87-
class function: pass
88-
class complex: pass
89-
class str: pass
77+
[builtins fixtures/dict.pyi]
9078

9179
[case testBytesLiteral]
9280
b, a = None, None # type: (bytes, A)
@@ -99,14 +87,7 @@ if str():
9987
if str():
10088
a = b'foo' # E: Incompatible types in assignment (expression has type "bytes", variable has type "A")
10189
class A: pass
102-
[file builtins.py]
103-
class object:
104-
def __init__(self): pass
105-
class type: pass
106-
class tuple: pass
107-
class function: pass
108-
class bytes: pass
109-
class str: pass
90+
[builtins fixtures/dict.pyi]
11091

11192
[case testUnicodeLiteralInPython3]
11293
s = None # type: str
@@ -2126,15 +2107,7 @@ if str():
21262107
....a # E: "ellipsis" has no attribute "a"
21272108

21282109
class A: pass
2129-
[file builtins.py]
2130-
class object:
2131-
def __init__(self): pass
2132-
class ellipsis:
2133-
def __init__(self): pass
2134-
__class__ = object()
2135-
class type: pass
2136-
class function: pass
2137-
class str: pass
2110+
[builtins fixtures/dict.pyi]
21382111
[out]
21392112

21402113

0 commit comments

Comments
 (0)