Skip to content

Commit 37be5c0

Browse files
committed
MAINT: tuple_arrays_from or just an array
`_helpers.tuple_arrays_from` is now localized to the normalizer only
1 parent d0423d7 commit 37be5c0

File tree

4 files changed

+49
-17
lines changed

4 files changed

+49
-17
lines changed

torch_np/_detail/implementations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def tile(tensor, reps):
372372
def atleast_1d(tensors):
373373
result = torch.atleast_1d(*tensors)
374374

375-
# match numpy: return a list not tuple;
375+
# match numpy: return a list not tuple;
376376
# >>> np.atleast_2d(np.arange(3))
377377
# array([[0, 1, 2]]) # a single 2D array
378378
# >>> torch.atleast_2d([torch.arange(3)])

torch_np/_funcs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
normalizer,
1616
)
1717

18-
1918
NoValue = _util.NoValue
2019

2120

torch_np/_normalizations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def wrapped(*args, **kwds):
168168
# XXX: change to separate tuple and list normalizers?
169169
return r.__origin__(_helpers.tuple_arrays_from(result))
170170
elif r == NDArrayOrSequence:
171-
# this is a single NDArray or tuple/list of NDArrays, e.g. atleast_1d
171+
# a variadic return: a single NDArray or tuple/list of NDArrays, e.g. atleast_1d
172172
if isinstance(result, (tuple, list)):
173173
seq = type(result)
174174
return seq(_helpers.tuple_arrays_from(result))

torch_np/_wrapper.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
ArrayLike,
1616
DTypeLike,
1717
NDArray,
18+
NDArrayOrSequence,
1819
SubokLike,
1920
UnpackedSeqArrayLike,
20-
NDArrayOrSequence,
2121
normalizer,
2222
)
2323

@@ -522,13 +522,23 @@ def meshgrid(
522522
return output
523523

524524

525-
@normalizer
526-
def indices(dimensions, dtype: DTypeLike = int, sparse=False):
527-
result = _impl.indices(dimensions, dtype=dtype, sparse=sparse)
525+
def indices(dimensions, dtype=int, sparse=False):
528526
if sparse:
529-
return _helpers.tuple_arrays_from(result)
527+
return _indices_sparse(dimensions, dtype)
530528
else:
531-
return _helpers.array_from(result)
529+
return _indices_full(dimensions, dtype)
530+
531+
532+
@normalizer
533+
def _indices_sparse(dimensions, dtype: DTypeLike = int) -> tuple[NDArray]:
534+
result = _impl.indices(dimensions, dtype=dtype, sparse=True)
535+
return result
536+
537+
538+
@normalizer
539+
def _indices_full(dimensions, dtype: DTypeLike = int) -> NDArray:
540+
result = _impl.indices(dimensions, dtype=dtype, sparse=False)
541+
return result
532542

533543

534544
@normalizer
@@ -581,20 +591,43 @@ def tri(N, M=None, k=0, dtype: DTypeLike = float, *, like: SubokLike = None) ->
581591
###### reductions
582592

583593

594+
def average(a, axis=None, weights=None, returned=False, keepdims=NoValue):
595+
if returned:
596+
return _average_with_ret(
597+
a, axis=axis, weights=weights, returned=True, keepdims=keepdims
598+
)
599+
else:
600+
return _average_without_ret(
601+
a, axis=axis, weights=weights, returned=False, keepdims=keepdims
602+
)
603+
604+
605+
@normalizer
606+
def _average_with_ret(
607+
a: ArrayLike,
608+
axis=None,
609+
weights: Optional[ArrayLike] = None,
610+
returned=True,
611+
*,
612+
keepdims=NoValue,
613+
) -> tuple[NDArray]:
614+
assert returned
615+
result, wsum = _impl.average(a, axis, weights, returned=True, keepdims=keepdims)
616+
return result, wsum
617+
618+
584619
@normalizer
585-
def average(
620+
def _average_without_ret(
586621
a: ArrayLike,
587622
axis=None,
588-
weights: ArrayLike = None,
623+
weights: Optional[ArrayLike] = None,
589624
returned=False,
590625
*,
591626
keepdims=NoValue,
592-
):
593-
result, wsum = _impl.average(a, axis, weights, returned=returned, keepdims=keepdims)
594-
if returned:
595-
return _helpers.tuple_arrays_from((result, wsum))
596-
else:
597-
return _helpers.array_from(result)
627+
) -> NDArray:
628+
assert not returned
629+
result, wsum = _impl.average(a, axis, weights, returned=False, keepdims=keepdims)
630+
return result
598631

599632

600633
@normalizer

0 commit comments

Comments
 (0)