Skip to content

Commit 5c33abf

Browse files
committed
Further improvements to functools.partial handling (#17425)
- Fixes another crash case / type inference in that case - Fix a false positive when calling the partially applied function with kwargs - TypeTraverse / comment / daemon test follow up ilevkivskyi mentioned on the original PR See also #17423
1 parent c37d972 commit 5c33abf

File tree

5 files changed

+169
-35
lines changed

5 files changed

+169
-35
lines changed

mypy/plugins/functools.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,14 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
245245
partial_kinds.append(fn_type.arg_kinds[i])
246246
partial_types.append(arg_type)
247247
partial_names.append(fn_type.arg_names[i])
248-
elif actuals:
249-
if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals):
248+
else:
249+
assert actuals
250+
if any(actual_arg_kinds[j] in (ArgKind.ARG_POS, ArgKind.ARG_STAR) for j in actuals):
251+
# Don't add params for arguments passed positionally
250252
continue
253+
# Add defaulted params for arguments passed via keyword
251254
kind = actual_arg_kinds[actuals[0]]
252-
if kind == ArgKind.ARG_NAMED:
255+
if kind == ArgKind.ARG_NAMED or kind == ArgKind.ARG_STAR2:
253256
kind = ArgKind.ARG_NAMED_OPT
254257
partial_kinds.append(kind)
255258
partial_types.append(arg_type)
@@ -286,15 +289,25 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
286289
if len(ctx.arg_types) != 2: # *args, **kwargs
287290
return ctx.default_return_type
288291

289-
args = [a for param in ctx.args for a in param]
290-
arg_kinds = [a for param in ctx.arg_kinds for a in param]
291-
arg_names = [a for param in ctx.arg_names for a in param]
292+
# See comments for similar actual to formal code above
293+
actual_args = []
294+
actual_arg_kinds = []
295+
actual_arg_names = []
296+
seen_args = set()
297+
for i, param in enumerate(ctx.args):
298+
for j, a in enumerate(param):
299+
if a in seen_args:
300+
continue
301+
seen_args.add(a)
302+
actual_args.append(a)
303+
actual_arg_kinds.append(ctx.arg_kinds[i][j])
304+
actual_arg_names.append(ctx.arg_names[i][j])
292305

293306
result = ctx.api.expr_checker.check_call(
294307
callee=partial_type,
295-
args=args,
296-
arg_kinds=arg_kinds,
297-
arg_names=arg_names,
308+
args=actual_args,
309+
arg_kinds=actual_arg_kinds,
310+
arg_names=actual_arg_names,
298311
context=ctx.context,
299312
)
300313
return result[0]

mypy/type_visitor.py

+1
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def visit_instance(self, t: Instance) -> Type:
213213
line=t.line,
214214
column=t.column,
215215
last_known_value=last_known_value,
216+
extra_attrs=t.extra_attrs,
216217
)
217218

218219
def visit_type_var(self, t: TypeVarType) -> Type:

mypy/types.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1417,8 +1417,7 @@ def __init__(
14171417
self._hash = -1
14181418

14191419
# Additional attributes defined per instance of this type. For example modules
1420-
# have different attributes per instance of types.ModuleType. This is intended
1421-
# to be "short-lived", we don't serialize it, and even don't store as variable type.
1420+
# have different attributes per instance of types.ModuleType.
14221421
self.extra_attrs = extra_attrs
14231422

14241423
def accept(self, visitor: TypeVisitor[T]) -> T:

test-data/unit/check-functools.test

+97-24
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ functools.partial(1) # E: "int" not callable \
191191

192192
[case testFunctoolsPartialStar]
193193
import functools
194+
from typing import List
194195

195196
def foo(a: int, b: str, *args: int, d: str, **kwargs: int) -> int: ...
196197

@@ -215,6 +216,13 @@ def bar(*a: bytes, **k: int):
215216
p1("a", **k) # E: Argument 2 to "foo" has incompatible type "**Dict[str, int]"; expected "str"
216217
p1(**k) # E: Argument 1 to "foo" has incompatible type "**Dict[str, int]"; expected "str"
217218
p1(*a) # E: List or tuple expected as variadic arguments
219+
220+
221+
def baz(a: int, b: int) -> int: ...
222+
def test_baz(xs: List[int]):
223+
p3 = functools.partial(baz, *xs)
224+
p3()
225+
p3(1) # E: Too many arguments for "baz"
218226
[builtins fixtures/dict.pyi]
219227

220228
[case testFunctoolsPartialGeneric]
@@ -408,33 +416,83 @@ def foo(cls3: Type[B[T]]):
408416
from typing_extensions import TypedDict, Unpack
409417
from functools import partial
410418

411-
class Data(TypedDict, total=False):
412-
x: int
413-
414-
def f(**kwargs: Unpack[Data]) -> None: ...
415-
def g(**kwargs: Unpack[Data]) -> None:
416-
partial(f, **kwargs)()
417-
418-
class MoreData(TypedDict, total=False):
419-
x: int
420-
y: int
419+
class D1(TypedDict, total=False):
420+
a1: int
421+
422+
def fn1(a1: int) -> None: ... # N: "fn1" defined here
423+
def main1(**d1: Unpack[D1]) -> None:
424+
partial(fn1, **d1)()
425+
partial(fn1, **d1)(**d1)
426+
partial(fn1, **d1)(a1=1)
427+
partial(fn1, **d1)(a1="asdf") # E: Argument "a1" to "fn1" has incompatible type "str"; expected "int"
428+
partial(fn1, **d1)(oops=1) # E: Unexpected keyword argument "oops" for "fn1"
429+
430+
def fn2(**kwargs: Unpack[D1]) -> None: ... # N: "fn2" defined here
431+
def main2(**d1: Unpack[D1]) -> None:
432+
partial(fn2, **d1)()
433+
partial(fn2, **d1)(**d1)
434+
partial(fn2, **d1)(a1=1)
435+
partial(fn2, **d1)(a1="asdf") # E: Argument "a1" to "fn2" has incompatible type "str"; expected "int"
436+
partial(fn2, **d1)(oops=1) # E: Unexpected keyword argument "oops" for "fn2"
437+
438+
class D2(TypedDict, total=False):
439+
a1: int
440+
a2: str
441+
442+
class A2Good(TypedDict, total=False):
443+
a2: str
444+
class A2Bad(TypedDict, total=False):
445+
a2: int
446+
447+
def fn3(a1: int, a2: str) -> None: ... # N: "fn3" defined here
448+
def main3(a2good: A2Good, a2bad: A2Bad, **d2: Unpack[D2]) -> None:
449+
partial(fn3, **d2)()
450+
partial(fn3, **d2)(a1=1, a2="asdf")
451+
452+
partial(fn3, **d2)(**d2)
453+
454+
partial(fn3, **d2)(a1="asdf") # E: Argument "a1" to "fn3" has incompatible type "str"; expected "int"
455+
partial(fn3, **d2)(a1=1, a2="asdf", oops=1) # E: Unexpected keyword argument "oops" for "fn3"
456+
457+
partial(fn3, **d2)(**a2good)
458+
partial(fn3, **d2)(**a2bad) # E: Argument "a2" to "fn3" has incompatible type "int"; expected "str"
459+
460+
def fn4(**kwargs: Unpack[D2]) -> None: ... # N: "fn4" defined here
461+
def main4(a2good: A2Good, a2bad: A2Bad, **d2: Unpack[D2]) -> None:
462+
partial(fn4, **d2)()
463+
partial(fn4, **d2)(a1=1, a2="asdf")
464+
465+
partial(fn4, **d2)(**d2)
466+
467+
partial(fn4, **d2)(a1="asdf") # E: Argument "a1" to "fn4" has incompatible type "str"; expected "int"
468+
partial(fn4, **d2)(a1=1, a2="asdf", oops=1) # E: Unexpected keyword argument "oops" for "fn4"
469+
470+
partial(fn3, **d2)(**a2good)
471+
partial(fn3, **d2)(**a2bad) # E: Argument "a2" to "fn3" has incompatible type "int"; expected "str"
472+
473+
def main5(**d2: Unpack[D2]) -> None:
474+
partial(fn1, **d2)() # E: Extra argument "a2" from **args for "fn1"
475+
partial(fn2, **d2)() # E: Extra argument "a2" from **args for "fn2"
476+
477+
def main6(a2good: A2Good, a2bad: A2Bad, **d1: Unpack[D1]) -> None:
478+
partial(fn3, **d1)() # E: Missing positional argument "a1" in call to "fn3"
479+
partial(fn3, **d1)("asdf") # E: Too many positional arguments for "fn3" \
480+
# E: Too few arguments for "fn3" \
481+
# E: Argument 1 to "fn3" has incompatible type "str"; expected "int"
482+
partial(fn3, **d1)(a2="asdf")
483+
partial(fn3, **d1)(**a2good)
484+
partial(fn3, **d1)(**a2bad) # E: Argument "a2" to "fn3" has incompatible type "int"; expected "str"
485+
486+
partial(fn4, **d1)()
487+
partial(fn4, **d1)("asdf") # E: Too many positional arguments for "fn4" \
488+
# E: Argument 1 to "fn4" has incompatible type "str"; expected "int"
489+
partial(fn4, **d1)(a2="asdf")
490+
partial(fn4, **d1)(**a2good)
491+
partial(fn4, **d1)(**a2bad) # E: Argument "a2" to "fn4" has incompatible type "int"; expected "str"
421492

422-
def f_more(**kwargs: Unpack[MoreData]) -> None: ...
423-
def g_more(**kwargs: Unpack[MoreData]) -> None:
424-
partial(f_more, **kwargs)()
425-
426-
class Good(TypedDict, total=False):
427-
y: int
428-
class Bad(TypedDict, total=False):
429-
y: str
430-
431-
def h(**kwargs: Unpack[Data]) -> None:
432-
bad: Bad
433-
partial(f_more, **kwargs)(**bad) # E: Argument "y" to "f_more" has incompatible type "str"; expected "int"
434-
good: Good
435-
partial(f_more, **kwargs)(**good)
436493
[builtins fixtures/dict.pyi]
437494

495+
438496
[case testFunctoolsPartialNestedGeneric]
439497
from functools import partial
440498
from typing import Generic, TypeVar, List
@@ -456,6 +514,21 @@ first_kw([1]) # E: Too many positional arguments for "get" \
456514
# E: Argument 1 to "get" has incompatible type "List[int]"; expected "int"
457515
[builtins fixtures/list.pyi]
458516

517+
[case testFunctoolsPartialHigherOrder]
518+
from functools import partial
519+
from typing import Callable
520+
521+
def fn(a: int, b: str, c: bytes) -> int: ...
522+
523+
def callback1(fn: Callable[[str, bytes], int]) -> None: ...
524+
def callback2(fn: Callable[[str, int], int]) -> None: ...
525+
526+
callback1(partial(fn, 1))
527+
# TODO: false negative
528+
# https://github.com/python/mypy/issues/17461
529+
callback2(partial(fn, 1))
530+
[builtins fixtures/tuple.pyi]
531+
459532
[case testFunctoolsPartialClassObjectMatchingPartial]
460533
from functools import partial
461534

test-data/unit/fine-grained.test

+48
Original file line numberDiff line numberDiff line change
@@ -10497,3 +10497,51 @@ from pkg.sub import modb
1049710497

1049810498
[out]
1049910499
==
10500+
10501+
[case testFineGrainedFunctoolsPartial]
10502+
import m
10503+
10504+
[file m.py]
10505+
from typing import Callable
10506+
from partial import p1
10507+
10508+
reveal_type(p1)
10509+
p1("a")
10510+
p1("a", 3)
10511+
p1("a", c=3)
10512+
p1(1, 3)
10513+
p1(1, "a", 3)
10514+
p1(a=1, b="a", c=3)
10515+
[builtins fixtures/dict.pyi]
10516+
10517+
[file partial.py]
10518+
from typing import Callable
10519+
import functools
10520+
10521+
def foo(a: int, b: str, c: int = 5) -> int: ...
10522+
p1 = foo
10523+
10524+
[file partial.py.2]
10525+
from typing import Callable
10526+
import functools
10527+
10528+
def foo(a: int, b: str, c: int = 5) -> int: ...
10529+
p1 = functools.partial(foo, 1)
10530+
10531+
[out]
10532+
m.py:4: note: Revealed type is "def (a: builtins.int, b: builtins.str, c: builtins.int =) -> builtins.int"
10533+
m.py:5: error: Too few arguments
10534+
m.py:5: error: Argument 1 has incompatible type "str"; expected "int"
10535+
m.py:6: error: Argument 1 has incompatible type "str"; expected "int"
10536+
m.py:6: error: Argument 2 has incompatible type "int"; expected "str"
10537+
m.py:7: error: Too few arguments
10538+
m.py:7: error: Argument 1 has incompatible type "str"; expected "int"
10539+
m.py:8: error: Argument 2 has incompatible type "int"; expected "str"
10540+
==
10541+
m.py:4: note: Revealed type is "functools.partial[builtins.int]"
10542+
m.py:8: error: Argument 1 to "foo" has incompatible type "int"; expected "str"
10543+
m.py:9: error: Too many arguments for "foo"
10544+
m.py:9: error: Argument 1 to "foo" has incompatible type "int"; expected "str"
10545+
m.py:9: error: Argument 2 to "foo" has incompatible type "str"; expected "int"
10546+
m.py:10: error: Unexpected keyword argument "a" for "foo"
10547+
partial.py:4: note: "foo" defined here

0 commit comments

Comments
 (0)