Skip to content

Commit bc9448e

Browse files
committed
MAINT: remove to_tensors_or_none, use Optional[ArrayLike] instead
1 parent d445dd0 commit bc9448e

File tree

3 files changed

+29
-46
lines changed

3 files changed

+29
-46
lines changed

torch_np/_helpers.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,3 @@ def to_tensors(*inputs):
118118

119119
return tuple(asarray(value).get() for value in inputs)
120120

121-
122-
def to_tensors_or_none(*inputs):
123-
"""Convert all array_likes from `inputs` to tensors. Nones pass through"""
124-
from ._ndarray import asarray, ndarray
125-
126-
return tuple(None if value is None else asarray(value).get() for value in inputs)

torch_np/_wrapper.py

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def _concat_check(tup, dtype, out):
109109

110110

111111
### XXX: order the imports DAG
112-
from . _funcs import normalizer, DTypeLike
112+
from . _funcs import normalizer, DTypeLike, ArrayLike
113+
from typing import Optional
113114

114115
@normalizer
115116
def concatenate(ar_tuple, axis=0, out=None, dtype: DTypeLike=None, casting="same_kind"):
@@ -368,55 +369,46 @@ def _xy_helper_corrcoef(x_tensor, y_tensor=None, rowvar=True):
368369
return x_tensor
369370

370371

371-
@_decorators.dtype_to_torch
372-
def corrcoef(x, y=None, rowvar=True, bias=NoValue, ddof=NoValue, *, dtype=None):
372+
#@_decorators.dtype_to_torch
373+
@normalizer
374+
def corrcoef(x : ArrayLike, y : Optional[ArrayLike]=None, rowvar=True, bias=NoValue, ddof=NoValue, *, dtype : DTypeLike=None):
373375
if bias is not None or ddof is not None:
374376
# deprecated in NumPy
375377
raise NotImplementedError
376-
377-
x_tensor, y_tensor = _helpers.to_tensors_or_none(x, y)
378-
tensor = _xy_helper_corrcoef(x_tensor, y_tensor, rowvar)
378+
tensor = _xy_helper_corrcoef(x, y, rowvar)
379379
result = _impl.corrcoef(tensor, dtype=dtype)
380380
return asarray(result)
381381

382382

383-
@_decorators.dtype_to_torch
383+
@normalizer
384384
def cov(
385-
m,
386-
y=None,
385+
m : ArrayLike,
386+
y : Optional[ArrayLike]=None,
387387
rowvar=True,
388388
bias=False,
389389
ddof=None,
390-
fweights=None,
391-
aweights=None,
390+
fweights : Optional[ArrayLike]=None,
391+
aweights : Optional[ArrayLike]=None,
392392
*,
393-
dtype=None,
393+
dtype : DTypeLike=None,
394394
):
395-
396-
m_tensor, y_tensor, fweights_tensor, aweights_tensor = _helpers.to_tensors_or_none(
397-
m, y, fweights, aweights
398-
)
399-
m_tensor = _xy_helper_corrcoef(m_tensor, y_tensor, rowvar)
400-
401-
result = _impl.cov(
402-
m_tensor, bias, ddof, fweights_tensor, aweights_tensor, dtype=dtype
403-
)
395+
m = _xy_helper_corrcoef(m, y, rowvar)
396+
result = _impl.cov(m, bias, ddof, fweights, aweights, dtype=dtype)
404397
return asarray(result)
405398

406399

407-
def bincount(x, /, weights=None, minlength=0):
408-
if not isinstance(x, ndarray) and x == []:
400+
@normalizer
401+
def bincount(x : ArrayLike, /, weights : Optional[ArrayLike]=None, minlength=0):
402+
if x.numel() == 0:
409403
# edge case allowed by numpy
410-
x = asarray([], dtype=int)
411-
412-
x_tensor, weights_tensor = _helpers.to_tensors_or_none(x, weights)
413-
result = _impl.bincount(x_tensor, weights_tensor, minlength)
404+
x = torch.as_tensor([], dtype=int)
405+
result = _impl.bincount(x, weights, minlength)
414406
return asarray(result)
415407

416408

417-
def where(condition, x=None, y=None, /):
418-
cond_t, x_t, y_t = _helpers.to_tensors_or_none(condition, x, y)
419-
result = _impl.where(cond_t, x_t, y_t)
409+
@normalizer
410+
def where(condition : ArrayLike, x : Optional[ArrayLike]=None, y: Optional[ArrayLike]=None, /):
411+
result = _impl.where(condition, x, y)
420412
if isinstance(result, tuple):
421413
# single-argument where(condition)
422414
return tuple(asarray(x) for x in result)
@@ -840,22 +832,19 @@ def nanpercentile():
840832
raise NotImplementedError
841833

842834

843-
def diff(a, n=1, axis=-1, prepend=NoValue, append=NoValue):
835+
@normalizer
836+
def diff(a : ArrayLike, n=1, axis=-1, prepend : Optional[ArrayLike]=NoValue, append : Optional[ArrayLike]=NoValue):
844837

845838
if n == 0:
846839
# match numpy and return the input immediately
847-
return a
848-
849-
a_tensor, prepend_tensor, append_tensor = _helpers.to_tensors_or_none(
850-
a, prepend, append
851-
)
840+
return asarray(a)
852841

853842
result = _impl.diff(
854-
a_tensor,
843+
a,
855844
n=n,
856845
axis=axis,
857-
prepend_tensor=prepend_tensor,
858-
append_tensor=append_tensor,
846+
prepend_tensor=prepend,
847+
append_tensor=append,
859848
)
860849
return asarray(result)
861850

torch_np/tests/numpy_tests/lib/test_function_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,7 @@ def test_n(self):
732732
assert_raises(ValueError, diff, x, n=-1)
733733
output = [diff(x, n=n) for n in range(1, 5)]
734734
expected = [[1, 1], [0], [], []]
735-
assert_(diff(x, n=0) is x)
735+
## assert_(diff(x, n=0) is x)
736736
for n, (expected, out) in enumerate(zip(expected, output), start=1):
737737
assert_(type(out) is np.ndarray)
738738
assert_array_equal(out, expected)

0 commit comments

Comments
 (0)