Skip to content

Commit 3aa3134

Browse files
committed
MAINT: simplify handling of variadic *args in normalize
1 parent 8c78725 commit 3aa3134

File tree

3 files changed

+24
-37
lines changed

3 files changed

+24
-37
lines changed

torch_np/_funcs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
DTypeLike,
1212
NDArray,
1313
SubokLike,
14-
UnpackedSeqArrayLike,
1514
normalizer,
1615
)
1716

torch_np/_normalizations.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
AxisLike = typing.TypeVar("AxisLike")
1515
NDArray = typing.TypeVar("NDarray")
1616

17-
# annotate e.g. atleast_1d(*arys)
18-
UnpackedSeqArrayLike = typing.TypeVar("UnpackedSeqArrayLike")
19-
2017

2118
import inspect
2219

@@ -76,7 +73,6 @@ def normalize_ndarray(arg, name=None):
7673
ArrayLike: normalize_array_like,
7774
Optional[ArrayLike]: normalize_optional_array_like,
7875
Sequence[ArrayLike]: normalize_seq_array_like,
79-
UnpackedSeqArrayLike: normalize_seq_array_like, # cf handling in normalize
8076
Optional[NDArray]: normalize_ndarray,
8177
DTypeLike: normalize_dtype,
8278
SubokLike: normalize_subok_like,
@@ -100,25 +96,20 @@ def normalizer(func):
10096
@functools.wraps(func)
10197
def wrapped(*args, **kwds):
10298
sig = inspect.signature(func)
103-
104-
# first, check for *args in positional parameters. Case in point:
105-
# atleast_1d(*arys: UnpackedSequenceArrayLike)
106-
# if found, consume all args into a tuple to normalize as a whole
107-
for j, param in enumerate(sig.parameters.values()):
108-
if param.annotation == UnpackedSeqArrayLike:
109-
if j == 0:
110-
args = (args,)
111-
else:
112-
# args = args[:j] + (args[j:],) would likely work
113-
# not present in numpy codebase, so do not bother just yet.
114-
# NB: branching on j ==0 is to avoid the empty tuple, args[:j]
115-
raise NotImplementedError
99+
sp = dict(sig.parameters)
100+
101+
# check for *args. If detected, duplicate the correspoding parameter
102+
# to have len(args) annotations for each element of *args.
103+
for j, param in enumerate(sp.values()):
104+
if param.kind == inspect.Parameter.VAR_POSITIONAL:
105+
sp.pop(param.name)
106+
variadic = {param.name + str(i): param for i in range(len(args))}
107+
variadic.update(sp)
108+
sp = variadic
116109
break
117110

118111
# normalize positional and keyword arguments
119112
# NB: extra unknown arguments: pass through, will raise in func(*lst) below
120-
sp = sig.parameters
121-
122113
lst = [normalize_this(arg, parm) for arg, parm in zip(args, sp.values())]
123114
lst += args[len(lst) :]
124115

torch_np/_wrapper.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
DTypeLike,
1818
NDArray,
1919
SubokLike,
20-
UnpackedSeqArrayLike,
2120
normalizer,
2221
)
2322

@@ -71,30 +70,30 @@ def copy(a: ArrayLike, order="K", subok: SubokLike = False):
7170

7271

7372
@normalizer
74-
def atleast_1d(*arys: UnpackedSeqArrayLike):
73+
def atleast_1d(*arys: ArrayLike):
7574
res = torch.atleast_1d(*arys)
76-
if len(res) == 1:
77-
return _helpers.array_from(res[0])
78-
else:
75+
if isinstance(res, tuple):
7976
return list(_helpers.tuple_arrays_from(res))
77+
else:
78+
return _helpers.array_from(res)
8079

8180

8281
@normalizer
83-
def atleast_2d(*arys: UnpackedSeqArrayLike):
82+
def atleast_2d(*arys: ArrayLike):
8483
res = torch.atleast_2d(*arys)
85-
if len(res) == 1:
86-
return _helpers.array_from(res[0])
87-
else:
84+
if isinstance(res, tuple):
8885
return list(_helpers.tuple_arrays_from(res))
86+
else:
87+
return _helpers.array_from(res)
8988

9089

9190
@normalizer
92-
def atleast_3d(*arys: UnpackedSeqArrayLike):
91+
def atleast_3d(*arys: ArrayLike):
9392
res = torch.atleast_3d(*arys)
94-
if len(res) == 1:
95-
return _helpers.array_from(res[0])
96-
else:
93+
if isinstance(res, tuple):
9794
return list(_helpers.tuple_arrays_from(res))
95+
else:
96+
return _helpers.array_from(res)
9897

9998

10099
def _concat_check(tup, dtype, out):
@@ -537,8 +536,7 @@ def broadcast_to(array: ArrayLike, shape, subok: SubokLike = False):
537536

538537
# YYY: pattern: tuple of arrays as input, tuple of arrays as output; cf nonzero
539538
@normalizer
540-
def broadcast_arrays(*args: UnpackedSeqArrayLike, subok: SubokLike = False):
541-
args = args[0] # undo the *args wrapping in normalizer
539+
def broadcast_arrays(*args: ArrayLike, subok: SubokLike = False):
542540
res = torch.broadcast_tensors(*args)
543541
return _helpers.tuple_arrays_from(res)
544542

@@ -565,8 +563,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
565563

566564

567565
@normalizer
568-
def meshgrid(*xi: UnpackedSeqArrayLike, copy=True, sparse=False, indexing="xy"):
569-
xi = xi[0] # undo the *xi wrapping in normalizer
566+
def meshgrid(*xi: ArrayLike, copy=True, sparse=False, indexing="xy"):
570567
output = _impl.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing)
571568
outp = _helpers.tuple_arrays_from(output)
572569
return list(outp) # match numpy, return a list

0 commit comments

Comments
 (0)