Skip to content

Commit fcd2010

Browse files
validate_call type params fix (#9760)
1 parent 04f3a46 commit fcd2010

File tree

4 files changed

+80
-5
lines changed

4 files changed

+80
-5
lines changed

pydantic/_internal/_typing_extra.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ def get_function_type_hints(
307307

308308
globalns = add_module_globals(function)
309309
type_hints = {}
310+
type_params: tuple[Any] = getattr(function, '__type_params__', ()) # type: ignore
310311
for name, value in annotations.items():
311312
if include_keys is not None and name not in include_keys:
312313
continue
@@ -315,7 +316,7 @@ def get_function_type_hints(
315316
elif isinstance(value, str):
316317
value = _make_forward_ref(value)
317318

318-
type_hints[name] = eval_type_backport(value, globalns, types_namespace)
319+
type_hints[name] = eval_type_backport(value, globalns, types_namespace, type_params)
319320

320321
return type_hints
321322

pydantic/_internal/_validate_call.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,13 @@ class ValidateCallWrapper:
2323
'__dict__', # required for __module__
2424
)
2525

26-
def __init__(self, function: Callable[..., Any], config: ConfigDict | None, validate_return: bool):
26+
def __init__(
27+
self,
28+
function: Callable[..., Any],
29+
config: ConfigDict | None,
30+
validate_return: bool,
31+
namespace: dict[str, Any] | None,
32+
):
2733
if isinstance(function, partial):
2834
func = function.func
2935
schema_type = func
@@ -36,7 +42,16 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
3642
self.__qualname__ = function.__qualname__
3743
self.__module__ = function.__module__
3844

39-
namespace = _typing_extra.add_module_globals(function, None)
45+
global_ns = _typing_extra.add_module_globals(function, None)
46+
# TODO: this is a bit of a hack, we should probably have a better way to handle this
47+
# specifically, we shouldn't be pumping the namespace full of type_params
48+
# when we take namespace and type_params arguments in eval_type_backport
49+
type_params = getattr(schema_type, '__type_params__', ())
50+
namespace = {
51+
**{param.__name__: param for param in type_params},
52+
**(global_ns or {}),
53+
**(namespace or {}),
54+
}
4055
config_wrapper = ConfigWrapper(config)
4156
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
4257
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))

pydantic/validate_call_decorator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import functools
66
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
77

8-
from ._internal import _validate_call
8+
from ._internal import _typing_extra, _validate_call
99

1010
__all__ = ('validate_call',)
1111

@@ -46,12 +46,14 @@ def validate_call(
4646
Returns:
4747
The decorated function.
4848
"""
49+
local_ns = _typing_extra.parent_frame_namespace()
4950

5051
def validate(function: AnyCallableT) -> AnyCallableT:
5152
if isinstance(function, (classmethod, staticmethod)):
5253
name = type(function).__name__
5354
raise TypeError(f'The `@{name}` decorator should be applied after `@validate_call` (put `@{name}` on top)')
54-
validate_call_wrapper = _validate_call.ValidateCallWrapper(function, config, validate_return)
55+
56+
validate_call_wrapper = _validate_call.ValidateCallWrapper(function, config, validate_return, local_ns)
5557

5658
@functools.wraps(function)
5759
def wrapper_function(*args, **kwargs):

tests/test_validate_call.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import inspect
33
import re
4+
import sys
45
from datetime import datetime, timezone
56
from functools import partial
67
from typing import Any, List, Tuple
@@ -803,3 +804,59 @@ def foo(bar: 'list[int | str]') -> 'list[int | str]':
803804
'input': {'not a str or int'},
804805
},
805806
]
807+
808+
809+
@pytest.mark.skipif(sys.version_info < (3, 12), reason='requires Python 3.12+ for PEP 695 syntax with generics')
810+
def test_validate_call_with_pep_695_syntax() -> None:
811+
"""Note: validate_call still doesn't work properly with generics, see https://github.com/pydantic/pydantic/issues/7796.
812+
813+
This test is just to ensure that the syntax is accepted and doesn't raise a NameError."""
814+
globs = {}
815+
exec(
816+
"""
817+
from typing import Iterable
818+
from pydantic import validate_call
819+
820+
@validate_call
821+
def find_max_no_validate_return[T](args: Iterable[T]) -> T:
822+
return sorted(args, reverse=True)[0]
823+
824+
@validate_call(validate_return=True)
825+
def find_max_validate_return[T](args: Iterable[T]) -> T:
826+
return sorted(args, reverse=True)[0]
827+
""",
828+
globs,
829+
)
830+
functions = [globs['find_max_no_validate_return'], globs['find_max_validate_return']]
831+
for find_max in functions:
832+
assert len(find_max.__type_params__) == 1
833+
assert find_max([1, 2, 10, 5]) == 10
834+
835+
with pytest.raises(ValidationError):
836+
find_max(1)
837+
838+
839+
class M0(BaseModel):
840+
z: int
841+
842+
843+
M = M0
844+
845+
846+
def test_uses_local_ns():
847+
class M1(BaseModel):
848+
y: int
849+
850+
M = M1 # noqa: F841
851+
852+
def foo():
853+
class M2(BaseModel):
854+
z: int
855+
856+
M = M2
857+
858+
@validate_call
859+
def bar(m: M) -> M:
860+
return m
861+
862+
assert bar({'z': 1}) == M2(z=1)

0 commit comments

Comments
 (0)