@@ -185,10 +185,20 @@ def assert_equal(actual, desired, err_msg="", verbose=True):
185
185
186
186
"""
187
187
__tracebackhide__ = True # Hide traceback for py.test
188
- try :
189
- return actual == desired
190
- except Exception :
191
- pass
188
+
189
+ num_nones = sum ([actual is None , desired is None ])
190
+ if num_nones == 1 :
191
+ raise AssertionError (f"Not equal: { actual } != { desired } " )
192
+ elif num_nones == 2 :
193
+ return True
194
+ # else, carry on
195
+
196
+ if isinstance (actual , np .DType ) or isinstance (desired , np .DType ):
197
+ result = actual == desired
198
+ if not result :
199
+ raise AssertionError (f"Not equal: { actual } != { desired } " )
200
+ else :
201
+ return True
192
202
193
203
if isinstance (desired , dict ):
194
204
if not isinstance (actual , dict ):
@@ -210,9 +220,6 @@ def assert_equal(actual, desired, err_msg="", verbose=True):
210
220
return assert_array_equal (actual , desired , err_msg , verbose )
211
221
msg = build_err_msg ([actual , desired ], err_msg , verbose = verbose )
212
222
213
- if isinstance (actual , np .DType ) and isinstance (desired , np .DType ):
214
- return actual == desired
215
-
216
223
# Handle complex numbers: separate into real/imag to handle
217
224
# nan/inf/negative zero correctly
218
225
# XXX: catch ValueError for subclasses of ndarray where iscomplex fail
0 commit comments