Skip to content

Commit 1b2937e

Browse files
authored
Merge pull request #36 from Quansight-Labs/testing_repr
BUG: testing: fix repr error on failure
2 parents bc9bd74 + 6616c28 commit 1b2937e

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

torch_np/testing/utils.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,10 @@ def build_err_msg(
116116
msg.append(err_msg)
117117
if verbose:
118118
for i, a in enumerate(arrays):
119-
120119
if isinstance(a, ndarray):
121120
# precision argument is only needed if the objects are ndarrays
122121
# r_func = partial(array_repr, precision=precision)
123-
r_func = a.get().__repr__ # XXX
122+
r_func = ndarray.__repr__
124123
else:
125124
r_func = repr
126125

@@ -660,7 +659,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval="nan"):
660659
# do not trigger a failure (np.ma.masked != True evaluates as
661660
# np.ma.masked, which is falsy).
662661
if not cond:
663-
n_mismatch = reduced.size - reduced.sum(dtype=intp)
662+
n_mismatch = reduced.size - int(reduced.sum(dtype=intp))
664663
n_elements = flagged.size if flagged.ndim != 0 else reduced.size
665664
percent_mismatch = 100 * n_mismatch / n_elements
666665
remarks = [
@@ -677,12 +676,9 @@ def func_assert_same_pos(x, y, func=isnan, hasval="nan"):
677676
error2 = abs(y - x)
678677
np.minimum(error, error2, out=error)
679678
max_abs_error = max(error)
680-
if getattr(error, "dtype", object_) == object_:
681-
remarks.append("Max absolute difference: " + str(max_abs_error))
682-
else:
683-
remarks.append(
684-
"Max absolute difference: " + array2string(max_abs_error)
685-
)
679+
remarks.append(
680+
"Max absolute difference: " + array2string(max_abs_error)
681+
)
686682

687683
# note: this definition of relative error matches that one
688684
# used by assert_allclose (found in np.isclose)
@@ -692,12 +688,9 @@ def func_assert_same_pos(x, y, func=isnan, hasval="nan"):
692688
max_rel_error = array(inf)
693689
else:
694690
max_rel_error = max(error[nonzero] / abs(y[nonzero]))
695-
if getattr(error, "dtype", object_) == object_:
696-
remarks.append("Max relative difference: " + str(max_rel_error))
697-
else:
698-
remarks.append(
699-
"Max relative difference: " + array2string(max_rel_error)
700-
)
691+
remarks.append(
692+
"Max relative difference: " + array2string(max_rel_error)
693+
)
701694

702695
err_msg += "\n" + "\n".join(remarks)
703696
msg = build_err_msg(

0 commit comments

Comments
 (0)