Skip to content

Commit d0423d7

Browse files
committed
MAINT: return annotations for tuples of NDArrays
Two cases - just {tuple, list}[NDArrays] (list is e.g. meshgrid) - a variadic return of a single NDArray or a list/tuple (where, atleast_{1,2,3}d, unique) Note that in the latter case the choice of the return type depends purely on the number of input array_likes.
1 parent fb65599 commit d0423d7

File tree

4 files changed

+102
-64
lines changed

4 files changed

+102
-64
lines changed

torch_np/_detail/implementations.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,36 @@ def tile(tensor, reps):
369369
return result
370370

371371

372+
def atleast_1d(tensors):
373+
result = torch.atleast_1d(*tensors)
374+
375+
# match numpy: return a list not tuple;
376+
# >>> np.atleast_2d(np.arange(3))
377+
# array([[0, 1, 2]]) # a single 2D array
378+
# >>> torch.atleast_2d([torch.arange(3)])
379+
# (tensor([[0, 1, 2]]), ) # 1-element tuple of a 2D tensor
380+
if isinstance(result, tuple):
381+
return list(result)
382+
else:
383+
return result
384+
385+
386+
def atleast_2d(tensors):
387+
result = torch.atleast_2d(*tensors)
388+
if isinstance(result, tuple):
389+
return list(result)
390+
else:
391+
return result
392+
393+
394+
def atleast_3d(tensors):
395+
result = torch.atleast_3d(*tensors)
396+
if isinstance(result, tuple):
397+
return list(result)
398+
else:
399+
return result
400+
401+
372402
# #### cov & corrcoef
373403

374404

torch_np/_funcs.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,22 @@
1515
normalizer,
1616
)
1717

18+
1819
NoValue = _util.NoValue
1920

2021

2122
@normalizer
22-
def nonzero(a: ArrayLike):
23+
def nonzero(a: ArrayLike) -> tuple[NDArray]:
2324
result = a.nonzero(as_tuple=True)
24-
return _helpers.tuple_arrays_from(result)
25+
return result
2526

2627

2728
@normalizer
2829
def argwhere(a: ArrayLike) -> NDArray:
2930
result = torch.argwhere(a)
3031
return result
3132

33+
3234
@normalizer
3335
def flatnonzero(a: ArrayLike) -> NDArray:
3436
result = a.ravel().nonzero(as_tuple=True)[0]
@@ -83,7 +85,9 @@ def trace(
8385

8486

8587
@normalizer
86-
def eye(N, M=None, k=0, dtype: DTypeLike = float, order="C", *, like: SubokLike = None) -> NDArray:
88+
def eye(
89+
N, M=None, k=0, dtype: DTypeLike = float, order="C", *, like: SubokLike = None
90+
) -> NDArray:
8791
if order != "C":
8892
raise NotImplementedError
8993
result = _impl.eye(N, M, k, dtype)
@@ -108,15 +112,16 @@ def diagflat(v: ArrayLike, k=0) -> NDArray:
108112
return result
109113

110114

111-
def diag_indices(n, ndim=2):
115+
@normalizer
116+
def diag_indices(n, ndim=2) -> tuple[NDArray]:
112117
result = _impl.diag_indices(n, ndim)
113-
return _helpers.tuple_arrays_from(result)
118+
return result
114119

115120

116121
@normalizer
117-
def diag_indices_from(arr: ArrayLike):
122+
def diag_indices_from(arr: ArrayLike) -> tuple[NDArray]:
118123
result = _impl.diag_indices_from(arr)
119-
return _helpers.tuple_arrays_from(result)
124+
return result
120125

121126

122127
@normalizer

torch_np/_normalizations.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
import operator
44
import typing
5-
from typing import Optional, Sequence
5+
from typing import Optional, Sequence, Union
66

77
import torch
88

@@ -18,6 +18,10 @@
1818
UnpackedSeqArrayLike = typing.TypeVar("UnpackedSeqArrayLike")
1919

2020

21+
# return value of atleast_1d et al: single array of a list/tuple of arrays
22+
NDArrayOrSequence = Union[NDArray, Sequence[NDArray]]
23+
24+
2125
import inspect
2226

2327
from . import _dtypes
@@ -154,11 +158,23 @@ def wrapped(*args, **kwds):
154158
result = func(*ba.args, **ba.kwargs)
155159

156160
# handle returns
157-
return_annotation = sig.return_annotation
158-
if return_annotation == NDArray:
161+
r = sig.return_annotation
162+
if r == NDArray:
159163
return _helpers.array_from(result)
160-
elif return_annotation == inspect._empty:
164+
elif r == inspect._empty:
161165
return result
166+
elif hasattr(r, "__origin__") and r.__origin__ in (list, tuple):
167+
# this is tuple[NDArray] or list[NDArray]
168+
# XXX: change to separate tuple and list normalizers?
169+
return r.__origin__(_helpers.tuple_arrays_from(result))
170+
elif r == NDArrayOrSequence:
171+
# this is a single NDArray or tuple/list of NDArrays, e.g. atleast_1d
172+
if isinstance(result, (tuple, list)):
173+
seq = type(result)
174+
return seq(_helpers.tuple_arrays_from(result))
175+
else:
176+
return _helpers.array_from(result)
177+
162178
else:
163179
raise ValueError(f"Unknown return annotation {return_annotation}")
164180

torch_np/_wrapper.py

Lines changed: 40 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
NDArray,
1818
SubokLike,
1919
UnpackedSeqArrayLike,
20+
NDArrayOrSequence,
2021
normalizer,
2122
)
2223

@@ -69,30 +70,21 @@ def copy(a: ArrayLike, order="K", subok: SubokLike = False) -> NDArray:
6970

7071

7172
@normalizer
72-
def atleast_1d(*arys: UnpackedSeqArrayLike):
73-
res = torch.atleast_1d(*arys)
74-
if len(res) == 1:
75-
return _helpers.array_from(res[0])
76-
else:
77-
return list(_helpers.tuple_arrays_from(res))
73+
def atleast_1d(*arys: UnpackedSeqArrayLike) -> NDArrayOrSequence:
74+
result = _impl.atleast_1d(*arys)
75+
return result
7876

7977

8078
@normalizer
81-
def atleast_2d(*arys: UnpackedSeqArrayLike):
82-
res = torch.atleast_2d(*arys)
83-
if len(res) == 1:
84-
return _helpers.array_from(res[0])
85-
else:
86-
return list(_helpers.tuple_arrays_from(res))
79+
def atleast_2d(*arys: UnpackedSeqArrayLike) -> NDArrayOrSequence:
80+
result = _impl.atleast_2d(*arys)
81+
return result
8782

8883

8984
@normalizer
90-
def atleast_3d(*arys: UnpackedSeqArrayLike):
91-
res = torch.atleast_3d(*arys)
92-
if len(res) == 1:
93-
return _helpers.array_from(res[0])
94-
else:
95-
return list(_helpers.tuple_arrays_from(res))
85+
def atleast_3d(*arys: UnpackedSeqArrayLike) -> NDArrayOrSequence:
86+
result = _impl.atleast_3d(*arys)
87+
return result
9688

9789

9890
def _concat_check(tup, dtype, out):
@@ -175,33 +167,33 @@ def stack(
175167

176168

177169
@normalizer
178-
def array_split(ary: ArrayLike, indices_or_sections, axis=0):
170+
def array_split(ary: ArrayLike, indices_or_sections, axis=0) -> tuple[NDArray]:
179171
result = _impl.split_helper(ary, indices_or_sections, axis)
180-
return _helpers.tuple_arrays_from(result)
172+
return result
181173

182174

183175
@normalizer
184-
def split(ary: ArrayLike, indices_or_sections, axis=0):
176+
def split(ary: ArrayLike, indices_or_sections, axis=0) -> tuple[NDArray]:
185177
result = _impl.split_helper(ary, indices_or_sections, axis, strict=True)
186-
return _helpers.tuple_arrays_from(result)
178+
return result
187179

188180

189181
@normalizer
190-
def hsplit(ary: ArrayLike, indices_or_sections):
182+
def hsplit(ary: ArrayLike, indices_or_sections) -> tuple[NDArray]:
191183
result = _impl.hsplit(ary, indices_or_sections)
192-
return _helpers.tuple_arrays_from(result)
184+
return result
193185

194186

195187
@normalizer
196-
def vsplit(ary: ArrayLike, indices_or_sections):
188+
def vsplit(ary: ArrayLike, indices_or_sections) -> tuple[NDArray]:
197189
result = _impl.vsplit(ary, indices_or_sections)
198-
return _helpers.tuple_arrays_from(result)
190+
return result
199191

200192

201193
@normalizer
202-
def dsplit(ary: ArrayLike, indices_or_sections):
194+
def dsplit(ary: ArrayLike, indices_or_sections) -> tuple[NDArray]:
203195
result = _impl.dsplit(ary, indices_or_sections)
204-
return _helpers.tuple_arrays_from(result)
196+
return result
205197

206198

207199
@normalizer
@@ -421,13 +413,9 @@ def where(
421413
x: Optional[ArrayLike] = None,
422414
y: Optional[ArrayLike] = None,
423415
/,
424-
):
416+
) -> NDArrayOrSequence:
425417
result = _impl.where(condition, x, y)
426-
if isinstance(result, tuple):
427-
# single-argument where(condition)
428-
return _helpers.tuple_arrays_from(result)
429-
else:
430-
return _helpers.array_from(result)
418+
return result
431419

432420

433421
###### module-level queries of object properties
@@ -496,10 +484,12 @@ def broadcast_to(array: ArrayLike, shape, subok: SubokLike = False) -> NDArray:
496484

497485
# YYY: pattern: tuple of arrays as input, tuple of arrays as output; cf nonzero
498486
@normalizer
499-
def broadcast_arrays(*args: UnpackedSeqArrayLike, subok: SubokLike = False):
487+
def broadcast_arrays(
488+
*args: UnpackedSeqArrayLike, subok: SubokLike = False
489+
) -> tuple[NDArray]:
500490
args = args[0] # undo the *args wrapping in normalizer
501491
res = torch.broadcast_tensors(*args)
502-
return _helpers.tuple_arrays_from(res)
492+
return res
503493

504494

505495
def unravel_index(indices, shape, order="C"):
@@ -524,11 +514,12 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
524514

525515

526516
@normalizer
527-
def meshgrid(*xi: UnpackedSeqArrayLike, copy=True, sparse=False, indexing="xy"):
517+
def meshgrid(
518+
*xi: UnpackedSeqArrayLike, copy=True, sparse=False, indexing="xy"
519+
) -> list[NDArray]:
528520
xi = xi[0] # undo the *xi wrapping in normalizer
529521
output = _impl.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing)
530-
outp = _helpers.tuple_arrays_from(output)
531-
return list(outp) # match numpy, return a list
522+
return output
532523

533524

534525
@normalizer
@@ -559,26 +550,26 @@ def triu(m: ArrayLike, k=0) -> NDArray:
559550
return result
560551

561552

562-
def tril_indices(n, k=0, m=None):
553+
def tril_indices(n, k=0, m=None) -> tuple[NDArray]:
563554
result = _impl.tril_indices(n, k, m)
564-
return _helpers.tuple_arrays_from(result)
555+
return result
565556

566557

567-
def triu_indices(n, k=0, m=None):
558+
def triu_indices(n, k=0, m=None) -> tuple[NDArray]:
568559
result = _impl.triu_indices(n, k, m)
569-
return _helpers.tuple_arrays_from(result)
560+
return result
570561

571562

572563
@normalizer
573-
def tril_indices_from(arr: ArrayLike, k=0):
564+
def tril_indices_from(arr: ArrayLike, k=0) -> tuple[NDArray]:
574565
result = _impl.tril_indices_from(arr, k)
575-
return _helpers.tuple_arrays_from(result)
566+
return result
576567

577568

578569
@normalizer
579-
def triu_indices_from(arr: ArrayLike, k=0):
570+
def triu_indices_from(arr: ArrayLike, k=0) -> tuple[NDArray]:
580571
result = _impl.triu_indices_from(arr, k)
581-
return _helpers.tuple_arrays_from(result)
572+
return result
582573

583574

584575
@normalizer
@@ -859,7 +850,7 @@ def unique(
859850
axis=None,
860851
*,
861852
equal_nan=True,
862-
):
853+
) -> NDArrayOrSequence:
863854
result = _impl.unique(
864855
ar,
865856
return_index=return_index,
@@ -868,11 +859,7 @@ def unique(
868859
axis=axis,
869860
equal_nan=equal_nan,
870861
)
871-
872-
if isinstance(result, tuple):
873-
return _helpers.tuple_arrays_from(result)
874-
else:
875-
return _helpers.array_from(result)
862+
return result
876863

877864

878865
###### mapping from numpy API objects to wrappers from this module ######

0 commit comments

Comments
 (0)