Skip to content

Commit 074d053

Browse files
authored
Backport PEP-696 specialisation on Python >=3.11.1 (#397)
1 parent 23378be commit 074d053

File tree

2 files changed

+165
-0
lines changed

2 files changed

+165
-0
lines changed

src/test_typing_extensions.py

+68
Original file line numberDiff line numberDiff line change
@@ -6402,6 +6402,34 @@ def test_typevartuple(self):
64026402
class A(Generic[Unpack[Ts]]): ...
64036403
Alias = Optional[Unpack[Ts]]
64046404

6405+
@skipIf(
6406+
sys.version_info < (3, 11, 1),
6407+
"Not yet backported for older versions of Python"
6408+
)
6409+
def test_typevartuple_specialization(self):
6410+
T = TypeVar("T")
6411+
Ts = TypeVarTuple('Ts', default=Unpack[Tuple[str, int]])
6412+
self.assertEqual(Ts.__default__, Unpack[Tuple[str, int]])
6413+
class A(Generic[T, Unpack[Ts]]): ...
6414+
self.assertEqual(A[float].__args__, (float, str, int))
6415+
self.assertEqual(A[float, range].__args__, (float, range))
6416+
self.assertEqual(A[float, Unpack[tuple[int, ...]]].__args__, (float, Unpack[tuple[int, ...]]))
6417+
6418+
@skipIf(
6419+
sys.version_info < (3, 11, 1),
6420+
"Not yet backported for older versions of Python"
6421+
)
6422+
def test_typevar_and_typevartuple_specialization(self):
6423+
T = TypeVar("T")
6424+
U = TypeVar("U", default=float)
6425+
Ts = TypeVarTuple('Ts', default=Unpack[Tuple[str, int]])
6426+
self.assertEqual(Ts.__default__, Unpack[Tuple[str, int]])
6427+
class A(Generic[T, U, Unpack[Ts]]): ...
6428+
self.assertEqual(A[int].__args__, (int, float, str, int))
6429+
self.assertEqual(A[int, str].__args__, (int, str, str, int))
6430+
self.assertEqual(A[int, str, range].__args__, (int, str, range))
6431+
self.assertEqual(A[int, str, Unpack[tuple[int, ...]]].__args__, (int, str, Unpack[tuple[int, ...]]))
6432+
64056433
def test_no_default_after_typevar_tuple(self):
64066434
T = TypeVar("T", default=int)
64076435
Ts = TypeVarTuple("Ts")
@@ -6487,6 +6515,46 @@ def test_allow_default_after_non_default_in_alias(self):
64876515
a4 = Callable[[Unpack[Ts]], T]
64886516
self.assertEqual(a4.__args__, (Unpack[Ts], T))
64896517

6518+
@skipIf(
6519+
sys.version_info < (3, 11, 1),
6520+
"Not yet backported for older versions of Python"
6521+
)
6522+
def test_paramspec_specialization(self):
6523+
T = TypeVar("T")
6524+
P = ParamSpec('P', default=[str, int])
6525+
self.assertEqual(P.__default__, [str, int])
6526+
class A(Generic[T, P]): ...
6527+
self.assertEqual(A[float].__args__, (float, (str, int)))
6528+
self.assertEqual(A[float, [range]].__args__, (float, (range,)))
6529+
6530+
@skipIf(
6531+
sys.version_info < (3, 11, 1),
6532+
"Not yet backported for older versions of Python"
6533+
)
6534+
def test_typevar_and_paramspec_specialization(self):
6535+
T = TypeVar("T")
6536+
U = TypeVar("U", default=float)
6537+
P = ParamSpec('P', default=[str, int])
6538+
self.assertEqual(P.__default__, [str, int])
6539+
class A(Generic[T, U, P]): ...
6540+
self.assertEqual(A[float].__args__, (float, float, (str, int)))
6541+
self.assertEqual(A[float, int].__args__, (float, int, (str, int)))
6542+
self.assertEqual(A[float, int, [range]].__args__, (float, int, (range,)))
6543+
6544+
@skipIf(
6545+
sys.version_info < (3, 11, 1),
6546+
"Not yet backported for older versions of Python"
6547+
)
6548+
def test_paramspec_and_typevar_specialization(self):
6549+
T = TypeVar("T")
6550+
P = ParamSpec('P', default=[str, int])
6551+
U = TypeVar("U", default=float)
6552+
self.assertEqual(P.__default__, [str, int])
6553+
class A(Generic[T, P, U]): ...
6554+
self.assertEqual(A[float].__args__, (float, (str, int), float))
6555+
self.assertEqual(A[float, [range]].__args__, (float, (range,), float))
6556+
self.assertEqual(A[float, [range], int].__args__, (float, (range,), int))
6557+
64906558

64916559
class NoDefaultTests(BaseTestCase):
64926560
@skip_if_py313_beta_1

src/typing_extensions.py

+97
Original file line numberDiff line numberDiff line change
@@ -1513,8 +1513,19 @@ def __new__(cls, name, *constraints, bound=None,
15131513
if infer_variance and (covariant or contravariant):
15141514
raise ValueError("Variance cannot be specified with infer_variance.")
15151515
typevar.__infer_variance__ = infer_variance
1516+
15161517
_set_default(typevar, default)
15171518
_set_module(typevar)
1519+
1520+
def _tvar_prepare_subst(alias, args):
1521+
if (
1522+
typevar.has_default()
1523+
and alias.__parameters__.index(typevar) == len(args)
1524+
):
1525+
args += (typevar.__default__,)
1526+
return args
1527+
1528+
typevar.__typing_prepare_subst__ = _tvar_prepare_subst
15181529
return typevar
15191530

15201531
def __init_subclass__(cls) -> None:
@@ -1613,6 +1624,24 @@ def __new__(cls, name, *, bound=None,
16131624

16141625
_set_default(paramspec, default)
16151626
_set_module(paramspec)
1627+
1628+
def _paramspec_prepare_subst(alias, args):
1629+
params = alias.__parameters__
1630+
i = params.index(paramspec)
1631+
if i == len(args) and paramspec.has_default():
1632+
args = [*args, paramspec.__default__]
1633+
if i >= len(args):
1634+
raise TypeError(f"Too few arguments for {alias}")
1635+
# Special case where Z[[int, str, bool]] == Z[int, str, bool] in PEP 612.
1636+
if len(params) == 1 and not typing._is_param_expr(args[0]):
1637+
assert i == 0
1638+
args = (args,)
1639+
# Convert lists to tuples to help other libraries cache the results.
1640+
elif isinstance(args[i], list):
1641+
args = (*args[:i], tuple(args[i]), *args[i + 1:])
1642+
return args
1643+
1644+
paramspec.__typing_prepare_subst__ = _paramspec_prepare_subst
16161645
return paramspec
16171646

16181647
def __init_subclass__(cls) -> None:
@@ -2311,6 +2340,17 @@ def __init__(self, getitem):
23112340
class _UnpackAlias(typing._GenericAlias, _root=True):
23122341
__class__ = typing.TypeVar
23132342

2343+
@property
2344+
def __typing_unpacked_tuple_args__(self):
2345+
assert self.__origin__ is Unpack
2346+
assert len(self.__args__) == 1
2347+
arg, = self.__args__
2348+
if isinstance(arg, (typing._GenericAlias, _types.GenericAlias)):
2349+
if arg.__origin__ is not tuple:
2350+
raise TypeError("Unpack[...] must be used with a tuple type")
2351+
return arg.__args__
2352+
return None
2353+
23142354
@_UnpackSpecialForm
23152355
def Unpack(self, parameters):
23162356
item = typing._type_check(parameters, f'{self._name} accepts only a single type.')
@@ -2340,6 +2380,16 @@ def _is_unpack(obj):
23402380

23412381
elif hasattr(typing, "TypeVarTuple"): # 3.11+
23422382

2383+
def _unpack_args(*args):
2384+
newargs = []
2385+
for arg in args:
2386+
subargs = getattr(arg, '__typing_unpacked_tuple_args__', None)
2387+
if subargs is not None and not (subargs and subargs[-1] is ...):
2388+
newargs.extend(subargs)
2389+
else:
2390+
newargs.append(arg)
2391+
return newargs
2392+
23432393
# Add default parameter - PEP 696
23442394
class TypeVarTuple(metaclass=_TypeVarLikeMeta):
23452395
"""Type variable tuple."""
@@ -2350,6 +2400,53 @@ def __new__(cls, name, *, default=NoDefault):
23502400
tvt = typing.TypeVarTuple(name)
23512401
_set_default(tvt, default)
23522402
_set_module(tvt)
2403+
2404+
def _typevartuple_prepare_subst(alias, args):
2405+
params = alias.__parameters__
2406+
typevartuple_index = params.index(tvt)
2407+
for param in params[typevartuple_index + 1:]:
2408+
if isinstance(param, TypeVarTuple):
2409+
raise TypeError(
2410+
f"More than one TypeVarTuple parameter in {alias}"
2411+
)
2412+
2413+
alen = len(args)
2414+
plen = len(params)
2415+
left = typevartuple_index
2416+
right = plen - typevartuple_index - 1
2417+
var_tuple_index = None
2418+
fillarg = None
2419+
for k, arg in enumerate(args):
2420+
if not isinstance(arg, type):
2421+
subargs = getattr(arg, '__typing_unpacked_tuple_args__', None)
2422+
if subargs and len(subargs) == 2 and subargs[-1] is ...:
2423+
if var_tuple_index is not None:
2424+
raise TypeError(
2425+
"More than one unpacked "
2426+
"arbitrary-length tuple argument"
2427+
)
2428+
var_tuple_index = k
2429+
fillarg = subargs[0]
2430+
if var_tuple_index is not None:
2431+
left = min(left, var_tuple_index)
2432+
right = min(right, alen - var_tuple_index - 1)
2433+
elif left + right > alen:
2434+
raise TypeError(f"Too few arguments for {alias};"
2435+
f" actual {alen}, expected at least {plen - 1}")
2436+
if left == alen - right and tvt.has_default():
2437+
replacement = _unpack_args(tvt.__default__)
2438+
else:
2439+
replacement = args[left: alen - right]
2440+
2441+
return (
2442+
*args[:left],
2443+
*([fillarg] * (typevartuple_index - left)),
2444+
replacement,
2445+
*([fillarg] * (plen - right - left - typevartuple_index - 1)),
2446+
*args[alen - right:],
2447+
)
2448+
2449+
tvt.__typing_prepare_subst__ = _typevartuple_prepare_subst
23532450
return tvt
23542451

23552452
def __init_subclass__(self, *args, **kwds):

0 commit comments

Comments
 (0)