Skip to content

Commit 9d75cab

Browse files
committed
MAINT: simplify handling of variadic *args in normalize
1 parent 8c78725 commit 9d75cab

File tree

3 files changed

+25
-44
lines changed

3 files changed

+25
-44
lines changed

torch_np/_funcs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
DTypeLike,
1212
NDArray,
1313
SubokLike,
14-
UnpackedSeqArrayLike,
1514
normalizer,
1615
)
1716

torch_np/_normalizations.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
AxisLike = typing.TypeVar("AxisLike")
1515
NDArray = typing.TypeVar("NDarray")
1616

17-
# annotate e.g. atleast_1d(*arys)
18-
UnpackedSeqArrayLike = typing.TypeVar("UnpackedSeqArrayLike")
19-
2017

2118
import inspect
2219

@@ -76,7 +73,6 @@ def normalize_ndarray(arg, name=None):
7673
ArrayLike: normalize_array_like,
7774
Optional[ArrayLike]: normalize_optional_array_like,
7875
Sequence[ArrayLike]: normalize_seq_array_like,
79-
UnpackedSeqArrayLike: normalize_seq_array_like, # cf handling in normalize
8076
Optional[NDArray]: normalize_ndarray,
8177
DTypeLike: normalize_dtype,
8278
SubokLike: normalize_subok_like,
@@ -100,25 +96,20 @@ def normalizer(func):
10096
@functools.wraps(func)
10197
def wrapped(*args, **kwds):
10298
sig = inspect.signature(func)
103-
104-
# first, check for *args in positional parameters. Case in point:
105-
# atleast_1d(*arys: UnpackedSequenceArrayLike)
106-
# if found, consume all args into a tuple to normalize as a whole
107-
for j, param in enumerate(sig.parameters.values()):
108-
if param.annotation == UnpackedSeqArrayLike:
109-
if j == 0:
110-
args = (args,)
111-
else:
112-
# args = args[:j] + (args[j:],) would likely work
113-
# not present in numpy codebase, so do not bother just yet.
114-
# NB: branching on j ==0 is to avoid the empty tuple, args[:j]
115-
raise NotImplementedError
99+
sp = dict(sig.parameters)
100+
101+
# check for *args. If detected, duplicate the correspoding parameter
102+
# to have len(args) annotations for each element of *args.
103+
for j, param in enumerate(sp.values()):
104+
if param.kind == inspect.Parameter.VAR_POSITIONAL:
105+
sp.pop(param.name)
106+
variadic = {param.name + str(i): param for i in range(len(args))}
107+
variadic.update(sp)
108+
sp = variadic
116109
break
117110

118111
# normalize positional and keyword arguments
119112
# NB: extra unknown arguments: pass through, will raise in func(*lst) below
120-
sp = sig.parameters
121-
122113
lst = [normalize_this(arg, parm) for arg, parm in zip(args, sp.values())]
123114
lst += args[len(lst) :]
124115

torch_np/_wrapper.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,7 @@
1212
from ._detail import _dtypes_impl, _flips, _reductions, _util
1313
from ._detail import implementations as _impl
1414
from ._ndarray import array, asarray, maybe_set_base, ndarray
15-
from ._normalizations import (
16-
ArrayLike,
17-
DTypeLike,
18-
NDArray,
19-
SubokLike,
20-
UnpackedSeqArrayLike,
21-
normalizer,
22-
)
15+
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
2316

2417
# Things to decide on (punt for now)
2518
#
@@ -71,30 +64,30 @@ def copy(a: ArrayLike, order="K", subok: SubokLike = False):
7164

7265

7366
@normalizer
74-
def atleast_1d(*arys: UnpackedSeqArrayLike):
67+
def atleast_1d(*arys: ArrayLike):
7568
res = torch.atleast_1d(*arys)
76-
if len(res) == 1:
77-
return _helpers.array_from(res[0])
78-
else:
69+
if isinstance(res, tuple):
7970
return list(_helpers.tuple_arrays_from(res))
71+
else:
72+
return _helpers.array_from(res)
8073

8174

8275
@normalizer
83-
def atleast_2d(*arys: UnpackedSeqArrayLike):
76+
def atleast_2d(*arys: ArrayLike):
8477
res = torch.atleast_2d(*arys)
85-
if len(res) == 1:
86-
return _helpers.array_from(res[0])
87-
else:
78+
if isinstance(res, tuple):
8879
return list(_helpers.tuple_arrays_from(res))
80+
else:
81+
return _helpers.array_from(res)
8982

9083

9184
@normalizer
92-
def atleast_3d(*arys: UnpackedSeqArrayLike):
85+
def atleast_3d(*arys: ArrayLike):
9386
res = torch.atleast_3d(*arys)
94-
if len(res) == 1:
95-
return _helpers.array_from(res[0])
96-
else:
87+
if isinstance(res, tuple):
9788
return list(_helpers.tuple_arrays_from(res))
89+
else:
90+
return _helpers.array_from(res)
9891

9992

10093
def _concat_check(tup, dtype, out):
@@ -537,8 +530,7 @@ def broadcast_to(array: ArrayLike, shape, subok: SubokLike = False):
537530

538531
# YYY: pattern: tuple of arrays as input, tuple of arrays as output; cf nonzero
539532
@normalizer
540-
def broadcast_arrays(*args: UnpackedSeqArrayLike, subok: SubokLike = False):
541-
args = args[0] # undo the *args wrapping in normalizer
533+
def broadcast_arrays(*args: ArrayLike, subok: SubokLike = False):
542534
res = torch.broadcast_tensors(*args)
543535
return _helpers.tuple_arrays_from(res)
544536

@@ -565,8 +557,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
565557

566558

567559
@normalizer
568-
def meshgrid(*xi: UnpackedSeqArrayLike, copy=True, sparse=False, indexing="xy"):
569-
xi = xi[0] # undo the *xi wrapping in normalizer
560+
def meshgrid(*xi: ArrayLike, copy=True, sparse=False, indexing="xy"):
570561
output = _impl.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing)
571562
outp = _helpers.tuple_arrays_from(output)
572563
return list(outp) # match numpy, return a list

0 commit comments

Comments
 (0)