Skip to content

Commit 1baf0a5

Browse files
authored
Backport generic TypedDicts (#46)
1 parent 7c28357 commit 1baf0a5

5 files changed

+201
-32
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
- Add `typing_extensions.NamedTuple`, allowing for generic `NamedTuple`s on
44
Python <3.11 (backport from python/cpython#92027, by Serhiy Storchaka). Patch
55
by Alex Waygood (@AlexWaygood).
6+
- Adjust `typing_extensions.TypedDict` to allow for generic `TypedDict`s on
7+
Python <3.11 (backport from python/cpython#27663, by Samodya Abey). Patch by
8+
Alex Waygood (@AlexWaygood).
69

710
# Release 4.2.0 (April 17, 2022)
811

README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ Certain objects were changed after they were added to `typing`, and
113113
- `TypedDict` does not store runtime information
114114
about which (if any) keys are non-required in Python 3.8, and does not
115115
honor the `total` keyword with old-style `TypedDict()` in Python
116-
3.9.0 and 3.9.1.
116+
3.9.0 and 3.9.1. `TypedDict` also does not support multiple inheritance
117+
with `typing.Generic` on Python <3.11.
117118
- `get_origin` and `get_args` lack support for `Annotated` in
118119
Python 3.8 and lack support for `ParamSpecArgs` and `ParamSpecKwargs`
119120
in 3.9.

src/_typed_dict_test_helper.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from __future__ import annotations
2+
3+
from typing import Generic, Optional, T
4+
from typing_extensions import TypedDict
5+
6+
7+
class FooGeneric(TypedDict, Generic[T]):
8+
a: Optional[T]

src/test_typing_extensions.py

+138
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from typing_extensions import assert_type, get_type_hints, get_origin, get_args
3030
from typing_extensions import clear_overloads, get_overloads, overload
3131
from typing_extensions import NamedTuple
32+
from _typed_dict_test_helper import FooGeneric
3233

3334
# Flags used to mark tests that only apply after a specific
3435
# version of the typing module.
@@ -1664,6 +1665,15 @@ class CustomProtocolWithoutInitB(Protocol):
16641665
self.assertEqual(CustomProtocolWithoutInitA.__init__, CustomProtocolWithoutInitB.__init__)
16651666

16661667

1668+
class Point2DGeneric(Generic[T], TypedDict):
1669+
a: T
1670+
b: T
1671+
1672+
1673+
class BarGeneric(FooGeneric[T], total=False):
1674+
b: int
1675+
1676+
16671677
class TypedDictTests(BaseTestCase):
16681678

16691679
def test_basics_iterable_syntax(self):
@@ -1769,14 +1779,24 @@ def test_pickle(self):
17691779
global EmpD # pickle wants to reference the class by name
17701780
EmpD = TypedDict('EmpD', name=str, id=int)
17711781
jane = EmpD({'name': 'jane', 'id': 37})
1782+
point = Point2DGeneric(a=5.0, b=3.0)
17721783
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1784+
# Test non-generic TypedDict
17731785
z = pickle.dumps(jane, proto)
17741786
jane2 = pickle.loads(z)
17751787
self.assertEqual(jane2, jane)
17761788
self.assertEqual(jane2, {'name': 'jane', 'id': 37})
17771789
ZZ = pickle.dumps(EmpD, proto)
17781790
EmpDnew = pickle.loads(ZZ)
17791791
self.assertEqual(EmpDnew({'name': 'jane', 'id': 37}), jane)
1792+
# and generic TypedDict
1793+
y = pickle.dumps(point, proto)
1794+
point2 = pickle.loads(y)
1795+
self.assertEqual(point, point2)
1796+
self.assertEqual(point2, {'a': 5.0, 'b': 3.0})
1797+
YY = pickle.dumps(Point2DGeneric, proto)
1798+
Point2DGenericNew = pickle.loads(YY)
1799+
self.assertEqual(Point2DGenericNew({'a': 5.0, 'b': 3.0}), point)
17801800

17811801
def test_optional(self):
17821802
EmpD = TypedDict('EmpD', name=str, id=int)
@@ -1854,6 +1874,124 @@ class PointDict3D(PointDict2D, total=False):
18541874
assert is_typeddict(PointDict2D) is True
18551875
assert is_typeddict(PointDict3D) is True
18561876

1877+
def test_get_type_hints_generic(self):
1878+
self.assertEqual(
1879+
get_type_hints(BarGeneric),
1880+
{'a': typing.Optional[T], 'b': int}
1881+
)
1882+
1883+
class FooBarGeneric(BarGeneric[int]):
1884+
c: str
1885+
1886+
self.assertEqual(
1887+
get_type_hints(FooBarGeneric),
1888+
{'a': typing.Optional[T], 'b': int, 'c': str}
1889+
)
1890+
1891+
def test_generic_inheritance(self):
1892+
class A(TypedDict, Generic[T]):
1893+
a: T
1894+
1895+
self.assertEqual(A.__bases__, (Generic, dict))
1896+
self.assertEqual(A.__orig_bases__, (TypedDict, Generic[T]))
1897+
self.assertEqual(A.__mro__, (A, Generic, dict, object))
1898+
self.assertEqual(A.__parameters__, (T,))
1899+
self.assertEqual(A[str].__parameters__, ())
1900+
self.assertEqual(A[str].__args__, (str,))
1901+
1902+
class A2(Generic[T], TypedDict):
1903+
a: T
1904+
1905+
self.assertEqual(A2.__bases__, (Generic, dict))
1906+
self.assertEqual(A2.__orig_bases__, (Generic[T], TypedDict))
1907+
self.assertEqual(A2.__mro__, (A2, Generic, dict, object))
1908+
self.assertEqual(A2.__parameters__, (T,))
1909+
self.assertEqual(A2[str].__parameters__, ())
1910+
self.assertEqual(A2[str].__args__, (str,))
1911+
1912+
class B(A[KT], total=False):
1913+
b: KT
1914+
1915+
self.assertEqual(B.__bases__, (Generic, dict))
1916+
self.assertEqual(B.__orig_bases__, (A[KT],))
1917+
self.assertEqual(B.__mro__, (B, Generic, dict, object))
1918+
self.assertEqual(B.__parameters__, (KT,))
1919+
self.assertEqual(B.__total__, False)
1920+
self.assertEqual(B.__optional_keys__, frozenset(['b']))
1921+
self.assertEqual(B.__required_keys__, frozenset(['a']))
1922+
1923+
self.assertEqual(B[str].__parameters__, ())
1924+
self.assertEqual(B[str].__args__, (str,))
1925+
self.assertEqual(B[str].__origin__, B)
1926+
1927+
class C(B[int]):
1928+
c: int
1929+
1930+
self.assertEqual(C.__bases__, (Generic, dict))
1931+
self.assertEqual(C.__orig_bases__, (B[int],))
1932+
self.assertEqual(C.__mro__, (C, Generic, dict, object))
1933+
self.assertEqual(C.__parameters__, ())
1934+
self.assertEqual(C.__total__, True)
1935+
self.assertEqual(C.__optional_keys__, frozenset(['b']))
1936+
self.assertEqual(C.__required_keys__, frozenset(['a', 'c']))
1937+
assert C.__annotations__ == {
1938+
'a': T,
1939+
'b': KT,
1940+
'c': int,
1941+
}
1942+
with self.assertRaises(TypeError):
1943+
C[str]
1944+
1945+
1946+
class Point3D(Point2DGeneric[T], Generic[T, KT]):
1947+
c: KT
1948+
1949+
self.assertEqual(Point3D.__bases__, (Generic, dict))
1950+
self.assertEqual(Point3D.__orig_bases__, (Point2DGeneric[T], Generic[T, KT]))
1951+
self.assertEqual(Point3D.__mro__, (Point3D, Generic, dict, object))
1952+
self.assertEqual(Point3D.__parameters__, (T, KT))
1953+
self.assertEqual(Point3D.__total__, True)
1954+
self.assertEqual(Point3D.__optional_keys__, frozenset())
1955+
self.assertEqual(Point3D.__required_keys__, frozenset(['a', 'b', 'c']))
1956+
assert Point3D.__annotations__ == {
1957+
'a': T,
1958+
'b': T,
1959+
'c': KT,
1960+
}
1961+
self.assertEqual(Point3D[int, str].__origin__, Point3D)
1962+
1963+
with self.assertRaises(TypeError):
1964+
Point3D[int]
1965+
1966+
with self.assertRaises(TypeError):
1967+
class Point3D(Point2DGeneric[T], Generic[KT]):
1968+
c: KT
1969+
1970+
def test_implicit_any_inheritance(self):
1971+
class A(TypedDict, Generic[T]):
1972+
a: T
1973+
1974+
class B(A[KT], total=False):
1975+
b: KT
1976+
1977+
class WithImplicitAny(B):
1978+
c: int
1979+
1980+
self.assertEqual(WithImplicitAny.__bases__, (Generic, dict,))
1981+
self.assertEqual(WithImplicitAny.__mro__, (WithImplicitAny, Generic, dict, object))
1982+
# Consistent with GenericTests.test_implicit_any
1983+
self.assertEqual(WithImplicitAny.__parameters__, ())
1984+
self.assertEqual(WithImplicitAny.__total__, True)
1985+
self.assertEqual(WithImplicitAny.__optional_keys__, frozenset(['b']))
1986+
self.assertEqual(WithImplicitAny.__required_keys__, frozenset(['a', 'c']))
1987+
assert WithImplicitAny.__annotations__ == {
1988+
'a': T,
1989+
'b': KT,
1990+
'c': int,
1991+
}
1992+
with self.assertRaises(TypeError):
1993+
WithImplicitAny[str]
1994+
18571995

18581996
class AnnotatedTests(BaseTestCase):
18591997

src/typing_extensions.py

+50-31
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,46 @@ def _is_callable_members_only(cls):
381381
return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))
382382

383383

384+
def _maybe_adjust_parameters(cls):
385+
"""Helper function used in Protocol.__init_subclass__ and _TypedDictMeta.__new__.
386+
387+
The contents of this function are very similar
388+
to logic found in typing.Generic.__init_subclass__
389+
on the CPython main branch.
390+
"""
391+
tvars = []
392+
if '__orig_bases__' in cls.__dict__:
393+
tvars = typing._collect_type_vars(cls.__orig_bases__)
394+
# Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn].
395+
# If found, tvars must be a subset of it.
396+
# If not found, tvars is it.
397+
# Also check for and reject plain Generic,
398+
# and reject multiple Generic[...] and/or Protocol[...].
399+
gvars = None
400+
for base in cls.__orig_bases__:
401+
if (isinstance(base, typing._GenericAlias) and
402+
base.__origin__ in (typing.Generic, Protocol)):
403+
# for error messages
404+
the_base = base.__origin__.__name__
405+
if gvars is not None:
406+
raise TypeError(
407+
"Cannot inherit from Generic[...]"
408+
" and/or Protocol[...] multiple types.")
409+
gvars = base.__parameters__
410+
if gvars is None:
411+
gvars = tvars
412+
else:
413+
tvarset = set(tvars)
414+
gvarset = set(gvars)
415+
if not tvarset <= gvarset:
416+
s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
417+
s_args = ', '.join(str(g) for g in gvars)
418+
raise TypeError(f"Some type variables ({s_vars}) are"
419+
f" not listed in {the_base}[{s_args}]")
420+
tvars = gvars
421+
cls.__parameters__ = tuple(tvars)
422+
423+
384424
# 3.8+
385425
if hasattr(typing, 'Protocol'):
386426
Protocol = typing.Protocol
@@ -477,43 +517,13 @@ def __class_getitem__(cls, params):
477517
return typing._GenericAlias(cls, params)
478518

479519
def __init_subclass__(cls, *args, **kwargs):
480-
tvars = []
481520
if '__orig_bases__' in cls.__dict__:
482521
error = typing.Generic in cls.__orig_bases__
483522
else:
484523
error = typing.Generic in cls.__bases__
485524
if error:
486525
raise TypeError("Cannot inherit from plain Generic")
487-
if '__orig_bases__' in cls.__dict__:
488-
tvars = typing._collect_type_vars(cls.__orig_bases__)
489-
# Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn].
490-
# If found, tvars must be a subset of it.
491-
# If not found, tvars is it.
492-
# Also check for and reject plain Generic,
493-
# and reject multiple Generic[...] and/or Protocol[...].
494-
gvars = None
495-
for base in cls.__orig_bases__:
496-
if (isinstance(base, typing._GenericAlias) and
497-
base.__origin__ in (typing.Generic, Protocol)):
498-
# for error messages
499-
the_base = base.__origin__.__name__
500-
if gvars is not None:
501-
raise TypeError(
502-
"Cannot inherit from Generic[...]"
503-
" and/or Protocol[...] multiple types.")
504-
gvars = base.__parameters__
505-
if gvars is None:
506-
gvars = tvars
507-
else:
508-
tvarset = set(tvars)
509-
gvarset = set(gvars)
510-
if not tvarset <= gvarset:
511-
s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
512-
s_args = ', '.join(str(g) for g in gvars)
513-
raise TypeError(f"Some type variables ({s_vars}) are"
514-
f" not listed in {the_base}[{s_args}]")
515-
tvars = gvars
516-
cls.__parameters__ = tuple(tvars)
526+
_maybe_adjust_parameters(cls)
517527

518528
# Determine if this is a protocol or a concrete subclass.
519529
if not cls.__dict__.get('_is_protocol', None):
@@ -614,6 +624,7 @@ def __index__(self) -> int:
614624
# keyword with old-style TypedDict(). See https://bugs.python.org/issue42059
615625
# The standard library TypedDict below Python 3.11 does not store runtime
616626
# information about optional and required keys when using Required or NotRequired.
627+
# Generic TypedDicts are also impossible using typing.TypedDict on Python <3.11.
617628
TypedDict = typing.TypedDict
618629
_TypedDictMeta = typing._TypedDictMeta
619630
is_typeddict = typing.is_typeddict
@@ -696,8 +707,16 @@ def __new__(cls, name, bases, ns, total=True):
696707
# Subclasses and instances of TypedDict return actual dictionaries
697708
# via _dict_new.
698709
ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new
710+
# Don't insert typing.Generic into __bases__ here,
711+
# or Generic.__init_subclass__ will raise TypeError
712+
# in the super().__new__() call.
713+
# Instead, monkey-patch __bases__ onto the class after it's been created.
699714
tp_dict = super().__new__(cls, name, (dict,), ns)
700715

716+
if any(issubclass(base, typing.Generic) for base in bases):
717+
tp_dict.__bases__ = (typing.Generic, dict)
718+
_maybe_adjust_parameters(tp_dict)
719+
701720
annotations = {}
702721
own_annotations = ns.get('__annotations__', {})
703722
msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"

0 commit comments

Comments
 (0)