|
29 | 29 | from typing_extensions import assert_type, get_type_hints, get_origin, get_args
|
30 | 30 | from typing_extensions import clear_overloads, get_overloads, overload
|
31 | 31 | from typing_extensions import NamedTuple
|
| 32 | +from _typed_dict_test_helper import FooGeneric |
32 | 33 |
|
33 | 34 | # Flags used to mark tests that only apply after a specific
|
34 | 35 | # version of the typing module.
|
@@ -1664,6 +1665,15 @@ class CustomProtocolWithoutInitB(Protocol):
|
1664 | 1665 | self.assertEqual(CustomProtocolWithoutInitA.__init__, CustomProtocolWithoutInitB.__init__)
|
1665 | 1666 |
|
1666 | 1667 |
|
| 1668 | +class Point2DGeneric(Generic[T], TypedDict): |
| 1669 | + a: T |
| 1670 | + b: T |
| 1671 | + |
| 1672 | + |
| 1673 | +class BarGeneric(FooGeneric[T], total=False): |
| 1674 | + b: int |
| 1675 | + |
| 1676 | + |
1667 | 1677 | class TypedDictTests(BaseTestCase):
|
1668 | 1678 |
|
1669 | 1679 | def test_basics_iterable_syntax(self):
|
@@ -1769,14 +1779,24 @@ def test_pickle(self):
|
1769 | 1779 | global EmpD # pickle wants to reference the class by name
|
1770 | 1780 | EmpD = TypedDict('EmpD', name=str, id=int)
|
1771 | 1781 | jane = EmpD({'name': 'jane', 'id': 37})
|
| 1782 | + point = Point2DGeneric(a=5.0, b=3.0) |
1772 | 1783 | for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
| 1784 | + # Test non-generic TypedDict |
1773 | 1785 | z = pickle.dumps(jane, proto)
|
1774 | 1786 | jane2 = pickle.loads(z)
|
1775 | 1787 | self.assertEqual(jane2, jane)
|
1776 | 1788 | self.assertEqual(jane2, {'name': 'jane', 'id': 37})
|
1777 | 1789 | ZZ = pickle.dumps(EmpD, proto)
|
1778 | 1790 | EmpDnew = pickle.loads(ZZ)
|
1779 | 1791 | self.assertEqual(EmpDnew({'name': 'jane', 'id': 37}), jane)
|
| 1792 | + # and generic TypedDict |
| 1793 | + y = pickle.dumps(point, proto) |
| 1794 | + point2 = pickle.loads(y) |
| 1795 | + self.assertEqual(point, point2) |
| 1796 | + self.assertEqual(point2, {'a': 5.0, 'b': 3.0}) |
| 1797 | + YY = pickle.dumps(Point2DGeneric, proto) |
| 1798 | + Point2DGenericNew = pickle.loads(YY) |
| 1799 | + self.assertEqual(Point2DGenericNew({'a': 5.0, 'b': 3.0}), point) |
1780 | 1800 |
|
1781 | 1801 | def test_optional(self):
|
1782 | 1802 | EmpD = TypedDict('EmpD', name=str, id=int)
|
@@ -1854,6 +1874,124 @@ class PointDict3D(PointDict2D, total=False):
|
1854 | 1874 | assert is_typeddict(PointDict2D) is True
|
1855 | 1875 | assert is_typeddict(PointDict3D) is True
|
1856 | 1876 |
|
| 1877 | + def test_get_type_hints_generic(self): |
| 1878 | + self.assertEqual( |
| 1879 | + get_type_hints(BarGeneric), |
| 1880 | + {'a': typing.Optional[T], 'b': int} |
| 1881 | + ) |
| 1882 | + |
| 1883 | + class FooBarGeneric(BarGeneric[int]): |
| 1884 | + c: str |
| 1885 | + |
| 1886 | + self.assertEqual( |
| 1887 | + get_type_hints(FooBarGeneric), |
| 1888 | + {'a': typing.Optional[T], 'b': int, 'c': str} |
| 1889 | + ) |
| 1890 | + |
| 1891 | + def test_generic_inheritance(self): |
| 1892 | + class A(TypedDict, Generic[T]): |
| 1893 | + a: T |
| 1894 | + |
| 1895 | + self.assertEqual(A.__bases__, (Generic, dict)) |
| 1896 | + self.assertEqual(A.__orig_bases__, (TypedDict, Generic[T])) |
| 1897 | + self.assertEqual(A.__mro__, (A, Generic, dict, object)) |
| 1898 | + self.assertEqual(A.__parameters__, (T,)) |
| 1899 | + self.assertEqual(A[str].__parameters__, ()) |
| 1900 | + self.assertEqual(A[str].__args__, (str,)) |
| 1901 | + |
| 1902 | + class A2(Generic[T], TypedDict): |
| 1903 | + a: T |
| 1904 | + |
| 1905 | + self.assertEqual(A2.__bases__, (Generic, dict)) |
| 1906 | + self.assertEqual(A2.__orig_bases__, (Generic[T], TypedDict)) |
| 1907 | + self.assertEqual(A2.__mro__, (A2, Generic, dict, object)) |
| 1908 | + self.assertEqual(A2.__parameters__, (T,)) |
| 1909 | + self.assertEqual(A2[str].__parameters__, ()) |
| 1910 | + self.assertEqual(A2[str].__args__, (str,)) |
| 1911 | + |
| 1912 | + class B(A[KT], total=False): |
| 1913 | + b: KT |
| 1914 | + |
| 1915 | + self.assertEqual(B.__bases__, (Generic, dict)) |
| 1916 | + self.assertEqual(B.__orig_bases__, (A[KT],)) |
| 1917 | + self.assertEqual(B.__mro__, (B, Generic, dict, object)) |
| 1918 | + self.assertEqual(B.__parameters__, (KT,)) |
| 1919 | + self.assertEqual(B.__total__, False) |
| 1920 | + self.assertEqual(B.__optional_keys__, frozenset(['b'])) |
| 1921 | + self.assertEqual(B.__required_keys__, frozenset(['a'])) |
| 1922 | + |
| 1923 | + self.assertEqual(B[str].__parameters__, ()) |
| 1924 | + self.assertEqual(B[str].__args__, (str,)) |
| 1925 | + self.assertEqual(B[str].__origin__, B) |
| 1926 | + |
| 1927 | + class C(B[int]): |
| 1928 | + c: int |
| 1929 | + |
| 1930 | + self.assertEqual(C.__bases__, (Generic, dict)) |
| 1931 | + self.assertEqual(C.__orig_bases__, (B[int],)) |
| 1932 | + self.assertEqual(C.__mro__, (C, Generic, dict, object)) |
| 1933 | + self.assertEqual(C.__parameters__, ()) |
| 1934 | + self.assertEqual(C.__total__, True) |
| 1935 | + self.assertEqual(C.__optional_keys__, frozenset(['b'])) |
| 1936 | + self.assertEqual(C.__required_keys__, frozenset(['a', 'c'])) |
| 1937 | + assert C.__annotations__ == { |
| 1938 | + 'a': T, |
| 1939 | + 'b': KT, |
| 1940 | + 'c': int, |
| 1941 | + } |
| 1942 | + with self.assertRaises(TypeError): |
| 1943 | + C[str] |
| 1944 | + |
| 1945 | + |
| 1946 | + class Point3D(Point2DGeneric[T], Generic[T, KT]): |
| 1947 | + c: KT |
| 1948 | + |
| 1949 | + self.assertEqual(Point3D.__bases__, (Generic, dict)) |
| 1950 | + self.assertEqual(Point3D.__orig_bases__, (Point2DGeneric[T], Generic[T, KT])) |
| 1951 | + self.assertEqual(Point3D.__mro__, (Point3D, Generic, dict, object)) |
| 1952 | + self.assertEqual(Point3D.__parameters__, (T, KT)) |
| 1953 | + self.assertEqual(Point3D.__total__, True) |
| 1954 | + self.assertEqual(Point3D.__optional_keys__, frozenset()) |
| 1955 | + self.assertEqual(Point3D.__required_keys__, frozenset(['a', 'b', 'c'])) |
| 1956 | + assert Point3D.__annotations__ == { |
| 1957 | + 'a': T, |
| 1958 | + 'b': T, |
| 1959 | + 'c': KT, |
| 1960 | + } |
| 1961 | + self.assertEqual(Point3D[int, str].__origin__, Point3D) |
| 1962 | + |
| 1963 | + with self.assertRaises(TypeError): |
| 1964 | + Point3D[int] |
| 1965 | + |
| 1966 | + with self.assertRaises(TypeError): |
| 1967 | + class Point3D(Point2DGeneric[T], Generic[KT]): |
| 1968 | + c: KT |
| 1969 | + |
| 1970 | + def test_implicit_any_inheritance(self): |
| 1971 | + class A(TypedDict, Generic[T]): |
| 1972 | + a: T |
| 1973 | + |
| 1974 | + class B(A[KT], total=False): |
| 1975 | + b: KT |
| 1976 | + |
| 1977 | + class WithImplicitAny(B): |
| 1978 | + c: int |
| 1979 | + |
| 1980 | + self.assertEqual(WithImplicitAny.__bases__, (Generic, dict,)) |
| 1981 | + self.assertEqual(WithImplicitAny.__mro__, (WithImplicitAny, Generic, dict, object)) |
| 1982 | + # Consistent with GenericTests.test_implicit_any |
| 1983 | + self.assertEqual(WithImplicitAny.__parameters__, ()) |
| 1984 | + self.assertEqual(WithImplicitAny.__total__, True) |
| 1985 | + self.assertEqual(WithImplicitAny.__optional_keys__, frozenset(['b'])) |
| 1986 | + self.assertEqual(WithImplicitAny.__required_keys__, frozenset(['a', 'c'])) |
| 1987 | + assert WithImplicitAny.__annotations__ == { |
| 1988 | + 'a': T, |
| 1989 | + 'b': KT, |
| 1990 | + 'c': int, |
| 1991 | + } |
| 1992 | + with self.assertRaises(TypeError): |
| 1993 | + WithImplicitAny[str] |
| 1994 | + |
1857 | 1995 |
|
1858 | 1996 | class AnnotatedTests(BaseTestCase):
|
1859 | 1997 |
|
|
0 commit comments