Skip to content

Commit 5806a5e

Browse files
committed
MAINT: split _equal, _isclose etc
1 parent 4331a01 commit 5806a5e

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

torch_np/_detail/implementations.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import torch
22

3-
from . import _util
3+
from . import _dtypes_impl, _util
4+
5+
# ### equality, equivalence, allclose ###
46

57

68
def tensor_equal(a1_t, a2_t, equal_nan=False):
@@ -19,6 +21,27 @@ def tensor_equal(a1_t, a2_t, equal_nan=False):
1921
return bool(result.all())
2022

2123

24+
def tensor_equiv(a1_t, a2_t):
25+
# *almost* the same as tensor_equal: _equiv tries to broadcast, _equal does not
26+
try:
27+
a1_t, a2_t = torch.broadcast_tensors(a1_t, a2_t)
28+
except RuntimeError:
29+
# failed to broadcast => not equivalent
30+
return False
31+
return tensor_equal(a1_t, a2_t)
32+
33+
34+
def tensor_isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
35+
dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
36+
a = a.to(dtype)
37+
b = b.to(dtype)
38+
result = torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
39+
return result
40+
41+
42+
# ### splits ###
43+
44+
2245
def split_helper(tensor, indices_or_sections, axis, strict=False):
2346
if isinstance(indices_or_sections, int):
2447
return split_helper_int(tensor, indices_or_sections, axis, strict)

torch_np/_wrapper.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,17 +1107,15 @@ def isscalar(a):
11071107

11081108

11091109
def isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
1110-
a, b = _helpers.to_tensors(a, b)
1111-
dtype = result_type(a, b)
1112-
torch_dtype = dtype.type.torch_dtype
1113-
a = a.to(torch_dtype)
1114-
b = b.to(torch_dtype)
1115-
return asarray(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
1110+
a_t, b_t = _helpers.to_tensors(a, b)
1111+
result = _impl.tensor_isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan)
1112+
return asarray(result)
11161113

11171114

11181115
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
1119-
arr_res = isclose(a, b, rtol, atol, equal_nan)
1120-
return arr_res.all()
1116+
a_t, b_t = _helpers.to_tensors(a, b)
1117+
result = _impl.tensor_isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan)
1118+
return result.all()
11211119

11221120

11231121
def array_equal(a1, a2, equal_nan=False):
@@ -1128,12 +1126,8 @@ def array_equal(a1, a2, equal_nan=False):
11281126

11291127
def array_equiv(a1, a2):
11301128
a1_t, a2_t = _helpers.to_tensors(a1, a2)
1131-
try:
1132-
a1_t, a2_t = torch.broadcast_tensors(a1_t, a2_t)
1133-
except RuntimeError:
1134-
# failed to broadcast => not equivalent
1135-
return False
1136-
return _impl.tensor_equal(a1_t, a2_t)
1129+
result = _impl.tensor_equiv(a1_t, a2_t)
1130+
return result
11371131

11381132

11391133
###### mapping from numpy API objects to wrappers from this module ######

0 commit comments

Comments
 (0)