Commit ff394e9

ENH: annotate *args
This is a bit clumsy: func(*args : Annotation) gives a single annotation for a runtime-determined number of arguments, and there is no way to annotate individual elements of *args AFAICS. Thus, register a special annotation to repack args into a tuple, and a normalizer to normalize this tuple.
1 parent: c75a8d5 · commit: ff394e9
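To illustrate the constraint, here is a minimal standard-library sketch (the demo function is hypothetical; the TypeVar name mirrors the diff below): inspect exposes *args as a single VAR_POSITIONAL parameter carrying one annotation, so the workaround is to repack the whole args tuple into one argument that a sequence normalizer can process.

```python
import inspect
import typing

UnpackedSeqArrayLike = typing.TypeVar("UnpackedSeqArrayLike")


def demo(*args: UnpackedSeqArrayLike):  # hypothetical function
    return args


# *args appears as ONE parameter with ONE annotation, regardless of
# how many arguments the caller passes at runtime.
param = next(iter(inspect.signature(demo).parameters.values()))
assert param.kind == inspect.Parameter.VAR_POSITIONAL
assert param.annotation is UnpackedSeqArrayLike

# The repacking trick: treat the whole tuple as a single argument, so
# a single Sequence-style normalizer can handle all elements at once.
args = (1, [2, 3])
args = (args,)  # ((1, [2, 3]),) -- now one argument, normalized as a whole
```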

File tree: 2 files changed, +47 −27 lines

torch_np/_funcs.py (33 additions, 14 deletions)
```diff
@@ -13,6 +13,9 @@
 DTypeLike = typing.TypeVar("DTypeLike")
 SubokLike = typing.TypeVar("SubokLike")
 
+# annotate e.g. atleast_1d(*arys)
+UnpackedSeqArrayLike = typing.TypeVar("UnpackedSeqArrayLike")
+
 
 import inspect
 
@@ -53,30 +56,49 @@ def normalize_subok_like(arg, name):
     ArrayLike: normalize_array_like,
     Optional[ArrayLike]: normalize_optional_array_like,
     Sequence[ArrayLike]: normalize_seq_array_like,
+    UnpackedSeqArrayLike: normalize_seq_array_like,  # cf handling in normalize
     DTypeLike: normalize_dtype,
     SubokLike: normalize_subok_like,
 }
 
 import functools
 
 
+def normalize_this(arg, parm):
+    """Normalize arg if a normalizer is registred."""
+    normalizer = normalizers.get(parm.annotation, None)
+    if normalizer:
+        return normalizer(arg)
+    else:
+        # untyped arguments pass through
+        return arg
+
+
 def normalizer(func):
     @functools.wraps(func)
     def wrapped(*args, **kwds):
         sig = inspect.signature(func)
 
-        lst, dct = [], {}
+        # first, check for *args in positional parameters. Case in point:
+        # atleast_1d(*arys: UnpackedSequenceArrayLike)
+        # if found, consume all args into a tuple to normalize as a whole
+        for j, param in enumerate(sig.parameters.values()):
+            if param.annotation == UnpackedSeqArrayLike:
+                if j == 0:
+                    args = (args,)
+                else:
+                    # args = args[:j] + (args[j:],) would likely work
+                    # not present in numpy codebase, so do not bother just yet.
+                    # NB: branching on j == 0 is to avoid the empty tuple, args[:j]
+                    raise NotImplementedError
+                break
+
         # loop over positional parameters and actual arguments
+        lst, dct = [], {}
         for arg, (name, parm) in zip(args, sig.parameters.items()):
             print(arg, name, parm.annotation)
-            normalizer = normalizers.get(parm.annotation, None)
-            if normalizer:
-                # dct[name] = normalizer(arg, name)
-                lst.append(normalizer(arg))
-            else:
-                # untyped arguments pass through
-                # dct[name] = arg
-                lst.append(arg)
+            lst.append(normalize_this(arg, parm))
+
 
         # normalize keyword arguments
         for name, arg in kwds.items():
@@ -88,11 +110,7 @@ def wrapped(*args, **kwds):
 
             print("kw: ", name, sig.parameters[name].annotation)
             parm = sig.parameters[name]
-            normalizer = normalizers.get(parm.annotation, None)
-            if normalizer:
-                dct[name] = normalizer(kwds[name], name)
-            else:
-                dct[name] = arg
+            dct[name] = normalize_this(arg, parm)
 
         ba = sig.bind(*lst, **dct)
         ba.apply_defaults()
@@ -113,6 +131,7 @@ def wrapped(*args, **kwds):
         # 5. axes : live in _impl or in types? several ways of handling them
         # 6. keepdims : peel off, postprocess
         # 7. OutLike : normal & keyword-only, peel off, postprocess
+        # 8. *args
 
         # finally, pass normalized arguments through
         result = func(*ba.args, **ba.kwargs)
```
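For a feel of the control flow, here is a self-contained, runnable model of the wrapped path above. It is a sketch, not the package's actual code: normalize_seq_array_like is a stand-in that floats its inputs instead of converting them to tensors.

```python
import functools
import inspect
import typing

UnpackedSeqArrayLike = typing.TypeVar("UnpackedSeqArrayLike")


# Stand-in for the real normalizer, which converts array-likes to tensors.
def normalize_seq_array_like(seq):
    return tuple(float(x) for x in seq)


normalizers = {UnpackedSeqArrayLike: normalize_seq_array_like}


def normalizer(func):
    @functools.wraps(func)
    def wrapped(*args, **kwds):
        sig = inspect.signature(func)
        # Repack *args into a single tuple argument, as in the diff above.
        for j, param in enumerate(sig.parameters.values()):
            if param.annotation == UnpackedSeqArrayLike:
                if j == 0:
                    args = (args,)
                else:
                    raise NotImplementedError
                break
        # Normalize each positional argument; untyped ones pass through.
        lst = [normalizers.get(p.annotation, lambda a: a)(arg)
               for arg, p in zip(args, sig.parameters.values())]
        ba = sig.bind(*lst, **kwds)
        ba.apply_defaults()
        return func(*ba.args, **ba.kwargs)
    return wrapped


@normalizer
def atleast_1d(*arys: UnpackedSeqArrayLike):
    # After repacking and sig.bind, arys is a 1-tuple holding the
    # normalized tuple: ((1.0, 2.0, 3.0),).
    return arys[0]


print(atleast_1d(1, 2, 3))  # (1.0, 2.0, 3.0)
```

Note the round trip: sig.bind re-wraps the normalized tuple under the VAR_POSITIONAL parameter, so the decorated function sees it as the sole element of arys. That is why the real atleast_1d below can call torch.atleast_1d(*arys) and hand torch a single tuple of tensors.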

torch_np/_wrapper.py (14 additions, 13 deletions)
```diff
@@ -13,6 +13,11 @@
 
 from . import _dtypes, _helpers, _decorators  # isort: skip # XXX
 
+### XXX: order the imports DAG
+from . _funcs import normalizer, DTypeLike, ArrayLike, UnpackedSeqArrayLike
+from typing import Optional, Sequence
+
+
 # Things to decide on (punt for now)
 #
 # 1. Q: What are the return types of wrapper functions: plain torch.Tensors or
@@ -63,27 +68,27 @@ def copy(a, order="K", subok=False):
     return a.copy(order="C")
 
 
-def atleast_1d(*arys):
-    tensors = _helpers.to_tensors(*arys)
-    res = torch.atleast_1d(tensors)
+@normalizer
+def atleast_1d(*arys : UnpackedSeqArrayLike):
+    res = torch.atleast_1d(*arys)
     if len(res) == 1:
         return asarray(res[0])
     else:
         return list(asarray(_) for _ in res)
 
 
-def atleast_2d(*arys):
-    tensors = _helpers.to_tensors(*arys)
-    res = torch.atleast_2d(tensors)
+@normalizer
+def atleast_2d(*arys : UnpackedSeqArrayLike):
+    res = torch.atleast_2d(*arys)
     if len(res) == 1:
         return asarray(res[0])
     else:
         return list(asarray(_) for _ in res)
 
 
-def atleast_3d(*arys):
-    tensors = _helpers.to_tensors(*arys)
-    res = torch.atleast_3d(tensors)
+@normalizer
+def atleast_3d(*arys : UnpackedSeqArrayLike):
+    res = torch.atleast_3d(*arys)
     if len(res) == 1:
         return asarray(res[0])
     else:
@@ -108,10 +113,6 @@ def _concat_check(tup, dtype, out):
     )
 
 
-### XXX: order the imports DAG
-from . _funcs import normalizer, DTypeLike, ArrayLike
-from typing import Optional, Sequence
-
 @normalizer
 def concatenate(ar_tuple : Sequence[ArrayLike], axis=0, out=None, dtype: DTypeLike=None, casting="same_kind"):
     _concat_check(ar_tuple, dtype, out=out)
```
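A usage sketch of the decorated wrappers (assuming the package imports as torch_np; the return shapes follow numpy's atleast_1d semantics):

```python
import torch_np as np

# One input: a scalar is promoted to a 1-d array.
a = np.atleast_1d(1.0)  # shape (1,)

# Several inputs: each is promoted, and a list comes back,
# matching the `list(asarray(_) for _ in res)` branch above.
x, y = np.atleast_1d(5, [[1, 2], [3, 4]])  # x: shape (1,), y: shape (2, 2)
```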
