Skip to content

Commit 011ecbf

Browse files
committed
MAINT: address further review comments
1 parent f9f9ba4 commit 011ecbf

File tree

2 files changed

+6
-14
lines changed

2 files changed

+6
-14
lines changed

torch_np/_funcs.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -889,20 +889,14 @@ def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False
889889
return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
890890

891891

892-
def _tensor_equal(a1_t, a2_t, equal_nan=False):
892+
def _tensor_equal(a1, a2, equal_nan=False):
893893
# Implementation of array_equal/array_equiv.
894-
if a1_t.shape != a2_t.shape:
895-
return False
896894
if equal_nan:
897-
nan_loc = (torch.isnan(a1_t) == torch.isnan(a2_t)).all()
898-
if nan_loc:
899-
# check the values
900-
result = a1_t[~torch.isnan(a1_t)] == a2_t[~torch.isnan(a2_t)]
901-
else:
902-
return False
895+
return (a1.shape == a2.shape) and (
896+
(a1 == a2) | (torch.isnan(a1) & torch.isnan(a2))
897+
).all().item()
903898
else:
904-
result = a1_t == a2_t
905-
return bool(result.all())
899+
return torch.equal(a1, a2)
906900

907901

908902
@normalizer
@@ -1798,8 +1792,6 @@ def i0(x: ArrayLike):
17981792
@normalizer(return_on_failure=False)
17991793
def isscalar(a: ArrayLike):
18001794
# XXX: this is a stub
1801-
if a is False:
1802-
return False
18031795
return a.numel() == 1
18041796

18051797

torch_np/_ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def reshape(self, *shape, order="C"):
357357

358358
def sort(self, axis=-1, kind=None, order=None):
359359
# ndarray.sort works in-place
360-
self.tensor = _funcs._sort(self.tensor, axis, kind, order)
360+
self.tensor.copy_(_funcs._sort(self.tensor, axis, kind, order))
361361

362362
argsort = _funcs.argsort
363363
searchsorted = _funcs.searchsorted

0 commit comments

Comments
 (0)