Skip to content

Commit 525ef1a

Browse files
committed
BUG: fix assert_equal for nones and dtypes
1 parent 17324f9 commit 525ef1a

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

torch_np/testing/utils.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,20 @@ def assert_equal(actual, desired, err_msg="", verbose=True):
185185
186186
"""
187187
__tracebackhide__ = True # Hide traceback for py.test
188-
try:
189-
return actual == desired
190-
except Exception:
191-
pass
188+
189+
num_nones = sum([actual is None, desired is None])
190+
if num_nones == 1:
191+
raise AssertionError(f"Not equal: {actual} != {desired}")
192+
elif num_nones == 2:
193+
return True
194+
# else, carry on
195+
196+
if isinstance(actual, np.DType) or isinstance(desired, np.DType):
197+
result = actual == desired
198+
if not result:
199+
raise AssertionError(f"Not equal: {actual} != {desired}")
200+
else:
201+
return True
192202

193203
if isinstance(desired, dict):
194204
if not isinstance(actual, dict):
@@ -210,9 +220,6 @@ def assert_equal(actual, desired, err_msg="", verbose=True):
210220
return assert_array_equal(actual, desired, err_msg, verbose)
211221
msg = build_err_msg([actual, desired], err_msg, verbose=verbose)
212222

213-
if isinstance(actual, np.DType) and isinstance(desired, np.DType):
214-
return actual == desired
215-
216223
# Handle complex numbers: separate into real/imag to handle
217224
# nan/inf/negative zero correctly
218225
# XXX: catch ValueError for subclasses of ndarray where iscomplex fail

0 commit comments

Comments
 (0)