Skip to content

Commit 7c28357

Browse files
Add a backport of generic NamedTuples (#44)
Co-authored-by: Jelle Zijlstra <[email protected]>
1 parent 7198c63 commit 7c28357

File tree

4 files changed

+403
-2
lines changed

4 files changed

+403
-2
lines changed

CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# Unreleased
2+
3+
- Add `typing_extensions.NamedTuple`, allowing for generic `NamedTuple`s on
4+
Python <3.11 (backport from python/cpython#92027, by Serhiy Storchaka). Patch
5+
by Alex Waygood (@AlexWaygood).
6+
17
# Release 4.2.0 (April 17, 2022)
28

39
- Re-export `typing.Unpack` and `typing.TypeVarTuple` on Python 3.11.

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ This module currently contains the following:
9696
- `Counter`
9797
- `DefaultDict`
9898
- `Deque`
99+
- `NamedTuple`
99100
- `NewType`
100101
- `NoReturn`
101102
- `overload`
@@ -121,6 +122,8 @@ Certain objects were changed after they were added to `typing`, and
121122
introspectable at runtime. In order to access overloads with
122123
`typing_extensions.get_overloads()`, you must use
123124
`@typing_extensions.overload`.
125+
- `NamedTuple` was changed in Python 3.11 to allow for multiple inheritance
126+
with `typing.Generic`.
124127

125128
There are a few types whose interface was modified between different
126129
versions of typing. For example, `typing.Sequence` was modified to

src/test_typing_extensions.py

+304-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import collections
66
from collections import defaultdict
77
import collections.abc
8+
import copy
89
from functools import lru_cache
910
import inspect
1011
import pickle
@@ -17,7 +18,7 @@
1718
from typing import TypeVar, Optional, Union, Any, AnyStr
1819
from typing import T, KT, VT # Not in __all__.
1920
from typing import Tuple, List, Dict, Iterable, Iterator, Callable
20-
from typing import Generic, NamedTuple
21+
from typing import Generic
2122
from typing import no_type_check
2223
import typing_extensions
2324
from typing_extensions import NoReturn, ClassVar, Final, IntVar, Literal, Type, NewType, TypedDict, Self
@@ -27,10 +28,12 @@
2728
from typing_extensions import TypeVarTuple, Unpack, dataclass_transform, reveal_type, Never, assert_never, LiteralString
2829
from typing_extensions import assert_type, get_type_hints, get_origin, get_args
2930
from typing_extensions import clear_overloads, get_overloads, overload
31+
from typing_extensions import NamedTuple
3032

3133
# Flags used to mark tests that only apply after a specific
3234
# version of the typing module.
3335
TYPING_3_8_0 = sys.version_info[:3] >= (3, 8, 0)
36+
TYPING_3_9_0 = sys.version_info[:3] >= (3, 9, 0)
3437
TYPING_3_10_0 = sys.version_info[:3] >= (3, 10, 0)
3538

3639
# 3.11 makes runtime type checks (_type_check) more lenient.
@@ -2874,7 +2877,7 @@ def test_typing_extensions_defers_when_possible(self):
28742877
if sys.version_info < (3, 10):
28752878
exclude |= {'get_args', 'get_origin'}
28762879
if sys.version_info < (3, 11):
2877-
exclude.add('final')
2880+
exclude |= {'final', 'NamedTuple'}
28782881
for item in typing_extensions.__all__:
28792882
if item not in exclude and hasattr(typing, item):
28802883
self.assertIs(
@@ -2892,6 +2895,305 @@ def test_typing_extensions_compiles_with_opt(self):
28922895
self.fail('Module does not compile with optimize=2 (-OO flag).')
28932896

28942897

2898+
class CoolEmployee(NamedTuple):
2899+
name: str
2900+
cool: int
2901+
2902+
2903+
class CoolEmployeeWithDefault(NamedTuple):
2904+
name: str
2905+
cool: int = 0
2906+
2907+
2908+
class XMeth(NamedTuple):
2909+
x: int
2910+
2911+
def double(self):
2912+
return 2 * self.x
2913+
2914+
2915+
class XRepr(NamedTuple):
2916+
x: int
2917+
y: int = 1
2918+
2919+
def __str__(self):
2920+
return f'{self.x} -> {self.y}'
2921+
2922+
def __add__(self, other):
2923+
return 0
2924+
2925+
2926+
@skipIf(TYPING_3_11_0, "These invariants should all be tested upstream on 3.11+")
2927+
class NamedTupleTests(BaseTestCase):
2928+
class NestedEmployee(NamedTuple):
2929+
name: str
2930+
cool: int
2931+
2932+
def test_basics(self):
2933+
Emp = NamedTuple('Emp', [('name', str), ('id', int)])
2934+
self.assertIsSubclass(Emp, tuple)
2935+
joe = Emp('Joe', 42)
2936+
jim = Emp(name='Jim', id=1)
2937+
self.assertIsInstance(joe, Emp)
2938+
self.assertIsInstance(joe, tuple)
2939+
self.assertEqual(joe.name, 'Joe')
2940+
self.assertEqual(joe.id, 42)
2941+
self.assertEqual(jim.name, 'Jim')
2942+
self.assertEqual(jim.id, 1)
2943+
self.assertEqual(Emp.__name__, 'Emp')
2944+
self.assertEqual(Emp._fields, ('name', 'id'))
2945+
self.assertEqual(Emp.__annotations__,
2946+
collections.OrderedDict([('name', str), ('id', int)]))
2947+
2948+
def test_annotation_usage(self):
2949+
tim = CoolEmployee('Tim', 9000)
2950+
self.assertIsInstance(tim, CoolEmployee)
2951+
self.assertIsInstance(tim, tuple)
2952+
self.assertEqual(tim.name, 'Tim')
2953+
self.assertEqual(tim.cool, 9000)
2954+
self.assertEqual(CoolEmployee.__name__, 'CoolEmployee')
2955+
self.assertEqual(CoolEmployee._fields, ('name', 'cool'))
2956+
self.assertEqual(CoolEmployee.__annotations__,
2957+
collections.OrderedDict(name=str, cool=int))
2958+
2959+
def test_annotation_usage_with_default(self):
2960+
jelle = CoolEmployeeWithDefault('Jelle')
2961+
self.assertIsInstance(jelle, CoolEmployeeWithDefault)
2962+
self.assertIsInstance(jelle, tuple)
2963+
self.assertEqual(jelle.name, 'Jelle')
2964+
self.assertEqual(jelle.cool, 0)
2965+
cooler_employee = CoolEmployeeWithDefault('Sjoerd', 1)
2966+
self.assertEqual(cooler_employee.cool, 1)
2967+
2968+
self.assertEqual(CoolEmployeeWithDefault.__name__, 'CoolEmployeeWithDefault')
2969+
self.assertEqual(CoolEmployeeWithDefault._fields, ('name', 'cool'))
2970+
self.assertEqual(CoolEmployeeWithDefault.__annotations__,
2971+
dict(name=str, cool=int))
2972+
2973+
with self.assertRaisesRegex(
2974+
TypeError,
2975+
'Non-default namedtuple field y cannot follow default field x'
2976+
):
2977+
class NonDefaultAfterDefault(NamedTuple):
2978+
x: int = 3
2979+
y: int
2980+
2981+
@skipUnless(
2982+
(
2983+
TYPING_3_8_0
2984+
or hasattr(CoolEmployeeWithDefault, '_field_defaults')
2985+
),
2986+
'"_field_defaults" attribute was added in a micro version of 3.7'
2987+
)
2988+
def test_field_defaults(self):
2989+
self.assertEqual(CoolEmployeeWithDefault._field_defaults, dict(cool=0))
2990+
2991+
def test_annotation_usage_with_methods(self):
2992+
self.assertEqual(XMeth(1).double(), 2)
2993+
self.assertEqual(XMeth(42).x, XMeth(42)[0])
2994+
self.assertEqual(str(XRepr(42)), '42 -> 1')
2995+
self.assertEqual(XRepr(1, 2) + XRepr(3), 0)
2996+
2997+
bad_overwrite_error_message = 'Cannot overwrite NamedTuple attribute'
2998+
2999+
with self.assertRaisesRegex(AttributeError, bad_overwrite_error_message):
3000+
class XMethBad(NamedTuple):
3001+
x: int
3002+
def _fields(self):
3003+
return 'no chance for this'
3004+
3005+
with self.assertRaisesRegex(AttributeError, bad_overwrite_error_message):
3006+
class XMethBad2(NamedTuple):
3007+
x: int
3008+
def _source(self):
3009+
return 'no chance for this as well'
3010+
3011+
def test_multiple_inheritance(self):
3012+
class A:
3013+
pass
3014+
with self.assertRaisesRegex(
3015+
TypeError,
3016+
'can only inherit from a NamedTuple type and Generic'
3017+
):
3018+
class X(NamedTuple, A):
3019+
x: int
3020+
3021+
with self.assertRaisesRegex(
3022+
TypeError,
3023+
'can only inherit from a NamedTuple type and Generic'
3024+
):
3025+
class X(NamedTuple, tuple):
3026+
x: int
3027+
3028+
with self.assertRaisesRegex(TypeError, 'duplicate base class'):
3029+
class X(NamedTuple, NamedTuple):
3030+
x: int
3031+
3032+
class A(NamedTuple):
3033+
x: int
3034+
with self.assertRaisesRegex(
3035+
TypeError,
3036+
'can only inherit from a NamedTuple type and Generic'
3037+
):
3038+
class X(NamedTuple, A):
3039+
y: str
3040+
3041+
def test_generic(self):
3042+
class X(NamedTuple, Generic[T]):
3043+
x: T
3044+
self.assertEqual(X.__bases__, (tuple, Generic))
3045+
self.assertEqual(X.__orig_bases__, (NamedTuple, Generic[T]))
3046+
self.assertEqual(X.__mro__, (X, tuple, Generic, object))
3047+
3048+
class Y(Generic[T], NamedTuple):
3049+
x: T
3050+
self.assertEqual(Y.__bases__, (Generic, tuple))
3051+
self.assertEqual(Y.__orig_bases__, (Generic[T], NamedTuple))
3052+
self.assertEqual(Y.__mro__, (Y, Generic, tuple, object))
3053+
3054+
for G in X, Y:
3055+
with self.subTest(type=G):
3056+
self.assertEqual(G.__parameters__, (T,))
3057+
A = G[int]
3058+
self.assertIs(A.__origin__, G)
3059+
self.assertEqual(A.__args__, (int,))
3060+
self.assertEqual(A.__parameters__, ())
3061+
3062+
a = A(3)
3063+
self.assertIs(type(a), G)
3064+
self.assertEqual(a.x, 3)
3065+
3066+
with self.assertRaisesRegex(TypeError, 'Too many parameters'):
3067+
G[int, str]
3068+
3069+
@skipUnless(TYPING_3_9_0, "tuple.__class_getitem__ was added in 3.9")
3070+
def test_non_generic_subscript_py39_plus(self):
3071+
# For backward compatibility, subscription works
3072+
# on arbitrary NamedTuple types.
3073+
class Group(NamedTuple):
3074+
key: T
3075+
group: list[T]
3076+
A = Group[int]
3077+
self.assertEqual(A.__origin__, Group)
3078+
self.assertEqual(A.__parameters__, ())
3079+
self.assertEqual(A.__args__, (int,))
3080+
a = A(1, [2])
3081+
self.assertIs(type(a), Group)
3082+
self.assertEqual(a, (1, [2]))
3083+
3084+
@skipIf(TYPING_3_9_0, "Test isn't relevant to 3.9+")
3085+
def test_non_generic_subscript_error_message_py38_minus(self):
3086+
class Group(NamedTuple):
3087+
key: T
3088+
group: List[T]
3089+
3090+
with self.assertRaisesRegex(TypeError, 'not subscriptable'):
3091+
Group[int]
3092+
3093+
for attr in ('__args__', '__origin__', '__parameters__'):
3094+
with self.subTest(attr=attr):
3095+
self.assertFalse(hasattr(Group, attr))
3096+
3097+
def test_namedtuple_keyword_usage(self):
3098+
LocalEmployee = NamedTuple("LocalEmployee", name=str, age=int)
3099+
nick = LocalEmployee('Nick', 25)
3100+
self.assertIsInstance(nick, tuple)
3101+
self.assertEqual(nick.name, 'Nick')
3102+
self.assertEqual(LocalEmployee.__name__, 'LocalEmployee')
3103+
self.assertEqual(LocalEmployee._fields, ('name', 'age'))
3104+
self.assertEqual(LocalEmployee.__annotations__, dict(name=str, age=int))
3105+
with self.assertRaisesRegex(
3106+
TypeError,
3107+
'Either list of fields or keywords can be provided to NamedTuple, not both'
3108+
):
3109+
NamedTuple('Name', [('x', int)], y=str)
3110+
3111+
def test_namedtuple_special_keyword_names(self):
3112+
NT = NamedTuple("NT", cls=type, self=object, typename=str, fields=list)
3113+
self.assertEqual(NT.__name__, 'NT')
3114+
self.assertEqual(NT._fields, ('cls', 'self', 'typename', 'fields'))
3115+
a = NT(cls=str, self=42, typename='foo', fields=[('bar', tuple)])
3116+
self.assertEqual(a.cls, str)
3117+
self.assertEqual(a.self, 42)
3118+
self.assertEqual(a.typename, 'foo')
3119+
self.assertEqual(a.fields, [('bar', tuple)])
3120+
3121+
def test_empty_namedtuple(self):
3122+
NT = NamedTuple('NT')
3123+
3124+
class CNT(NamedTuple):
3125+
pass # empty body
3126+
3127+
for struct in [NT, CNT]:
3128+
with self.subTest(struct=struct):
3129+
self.assertEqual(struct._fields, ())
3130+
self.assertEqual(struct.__annotations__, {})
3131+
self.assertIsInstance(struct(), struct)
3132+
# Attribute was added in a micro version of 3.7
3133+
# and is tested more fully elsewhere
3134+
if hasattr(struct, "_field_defaults"):
3135+
self.assertEqual(struct._field_defaults, {})
3136+
3137+
def test_namedtuple_errors(self):
3138+
with self.assertRaises(TypeError):
3139+
NamedTuple.__new__()
3140+
with self.assertRaises(TypeError):
3141+
NamedTuple()
3142+
with self.assertRaises(TypeError):
3143+
NamedTuple('Emp', [('name', str)], None)
3144+
with self.assertRaisesRegex(ValueError, 'cannot start with an underscore'):
3145+
NamedTuple('Emp', [('_name', str)])
3146+
with self.assertRaises(TypeError):
3147+
NamedTuple(typename='Emp', name=str, id=int)
3148+
3149+
def test_copy_and_pickle(self):
3150+
global Emp # pickle wants to reference the class by name
3151+
Emp = NamedTuple('Emp', [('name', str), ('cool', int)])
3152+
for cls in Emp, CoolEmployee, self.NestedEmployee:
3153+
with self.subTest(cls=cls):
3154+
jane = cls('jane', 37)
3155+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
3156+
z = pickle.dumps(jane, proto)
3157+
jane2 = pickle.loads(z)
3158+
self.assertEqual(jane2, jane)
3159+
self.assertIsInstance(jane2, cls)
3160+
3161+
jane2 = copy.copy(jane)
3162+
self.assertEqual(jane2, jane)
3163+
self.assertIsInstance(jane2, cls)
3164+
3165+
jane2 = copy.deepcopy(jane)
3166+
self.assertEqual(jane2, jane)
3167+
self.assertIsInstance(jane2, cls)
3168+
3169+
def test_docstring(self):
3170+
self.assertEqual(NamedTuple.__doc__, typing.NamedTuple.__doc__)
3171+
self.assertIsInstance(NamedTuple.__doc__, str)
3172+
3173+
@skipUnless(TYPING_3_8_0, "NamedTuple had a bad signature on <=3.7")
3174+
def test_signature_is_same_as_typing_NamedTuple(self):
3175+
self.assertEqual(inspect.signature(NamedTuple), inspect.signature(typing.NamedTuple))
3176+
3177+
@skipIf(TYPING_3_8_0, "tests are only relevant to <=3.7")
3178+
def test_signature_on_37(self):
3179+
self.assertIsInstance(inspect.signature(NamedTuple), inspect.Signature)
3180+
self.assertFalse(hasattr(NamedTuple, "__text_signature__"))
3181+
3182+
@skipUnless(TYPING_3_9_0, "NamedTuple was a class on 3.8 and lower")
3183+
def test_same_as_typing_NamedTuple_39_plus(self):
3184+
self.assertEqual(
3185+
set(dir(NamedTuple)),
3186+
set(dir(typing.NamedTuple)) | {"__text_signature__"}
3187+
)
3188+
self.assertIs(type(NamedTuple), type(typing.NamedTuple))
3189+
3190+
@skipIf(TYPING_3_9_0, "tests are only relevant to <=3.8")
3191+
def test_same_as_typing_NamedTuple_38_minus(self):
3192+
self.assertEqual(
3193+
self.NestedEmployee.__annotations__,
3194+
self.NestedEmployee._field_types
3195+
)
3196+
28953197

28963198
if __name__ == '__main__':
28973199
main()

0 commit comments

Comments
 (0)