1
1
import numpy as np
2
2
3
3
from pandas import compat
4
- from pandas.core.common import isnull, array_equivalent
4
+ from pandas.core.common import isnull, array_equivalent, is_dtype_equal
5
5
6
6
cdef NUMERIC_TYPES = (
7
7
bool ,
@@ -55,7 +55,7 @@ cpdef assert_dict_equal(a, b, bint compare_keys=True):
55
55
56
56
return True
57
57
58
- cpdef assert_almost_equal(a, b, bint check_less_precise = False ,
58
+ cpdef assert_almost_equal(a, b, bint check_less_precise = False , check_dtype = True ,
59
59
obj = None , lobj = None , robj = None ):
60
60
""" Check that left and right objects are almost equal.
61
61
@@ -66,6 +66,8 @@ cpdef assert_almost_equal(a, b, bint check_less_precise=False,
66
66
check_less_precise : bool, default False
67
67
Specify comparison precision.
68
68
5 digits (False) or 3 digits (True) after decimal points are compared.
69
+ check_dtype: bool, default True
70
+ check dtype if both a and b are np.ndarray
69
71
obj : str, default None
70
72
Specify object name being compared, internally used to show appropriate
71
73
assertion message
@@ -82,7 +84,7 @@ cpdef assert_almost_equal(a, b, bint check_less_precise=False,
82
84
double diff = 0.0
83
85
Py_ssize_t i, na, nb
84
86
double fa, fb
85
- bint is_unequal = False
87
+ bint is_unequal = False , a_is_ndarray, b_is_ndarray
86
88
87
89
if lobj is None :
88
90
lobj = a
@@ -97,36 +99,43 @@ cpdef assert_almost_equal(a, b, bint check_less_precise=False,
97
99
assert a == b, " %r != %r " % (a, b)
98
100
return True
99
101
102
+ a_is_ndarray = isinstance (a, np.ndarray)
103
+ b_is_ndarray = isinstance (b, np.ndarray)
104
+
105
+ if obj is None :
106
+ if a_is_ndarray or b_is_ndarray:
107
+ obj = ' numpy array'
108
+ else :
109
+ obj = ' Iterable'
110
+
100
111
if isiterable(a):
101
112
102
113
if not isiterable(b):
103
- from pandas.util.testing import raise_assert_detail
104
- if obj is None :
105
- obj = ' Iterable'
106
- msg = " First object is iterable, second isn't"
107
- raise_assert_detail(obj, msg, a, b)
114
+ from pandas.util.testing import assert_class_equal
115
+ # classes can't be the same, to raise error
116
+ assert_class_equal(a, b, obj = obj)
108
117
109
118
assert has_length(a) and has_length(b), (
110
119
" Can't compare objects without length, one or both is invalid: "
111
- " (%r , %r )" % (a, b)
112
- )
120
+ " (%r , %r )" % (a, b))
113
121
114
- if isinstance (a, np.ndarray) and isinstance (b, np.ndarray):
115
- if obj is None :
116
- obj = ' numpy array'
122
+ if a_is_ndarray and b_is_ndarray:
117
123
na, nb = a.size, b.size
118
124
if a.shape != b.shape:
119
125
from pandas.util.testing import raise_assert_detail
120
126
raise_assert_detail(obj, ' {0} shapes are different' .format(obj),
121
127
a.shape, b.shape)
128
+
129
+ if check_dtype and not is_dtype_equal(a, b):
130
+ from pandas.util.testing import assert_attr_equal
131
+ assert_attr_equal(' dtype' , a, b, obj = obj)
132
+
122
133
try :
123
134
if array_equivalent(a, b, strict_nan = True ):
124
135
return True
125
136
except :
126
137
pass
127
138
else :
128
- if obj is None :
129
- obj = ' Iterable'
130
139
na, nb = len (a), len (b)
131
140
132
141
if na != nb:
@@ -149,54 +158,38 @@ cpdef assert_almost_equal(a, b, bint check_less_precise=False,
149
158
return True
150
159
151
160
elif isiterable(b):
152
- from pandas.util.testing import raise_assert_detail
153
- if obj is None :
154
- obj = ' Iterable'
155
- msg = " Second object is iterable, first isn't"
156
- raise_assert_detail(obj, msg, a, b)
161
+ from pandas.util.testing import assert_class_equal
162
+ # classes can't be the same, to raise error
163
+ assert_class_equal(a, b, obj = obj)
157
164
158
- if isnull(a):
159
- assert isnull(b), (
160
- " First object is null, second isn't: %r != %r " % (a, b)
161
- )
165
+ if a == b:
166
+ # object comparison
162
167
return True
163
- elif isnull(b):
164
- assert isnull(a), (
165
- " First object is not null, second is null: %r != %r " % (a, b)
166
- )
168
+ if isnull(a) and isnull(b):
169
+ # nan / None comparison
167
170
return True
168
-
169
- if is_comparable_as_number(a):
170
- assert is_comparable_as_number(b), (
171
- " First object is numeric, second is not: %r != %r " % (a, b)
172
- )
171
+ if is_comparable_as_number(a) and is_comparable_as_number(b):
172
+ if array_equivalent(a, b, strict_nan = True ):
173
+ # inf comparison
174
+ return True
173
175
174
176
decimal = 5
175
177
176
178
# deal with differing dtypes
177
179
if check_less_precise:
178
180
decimal = 3
179
181
180
- if np.isinf(a):
181
- assert np.isinf(b), " First object is inf, second isn't"
182
- if np.isposinf(a):
183
- assert np.isposinf(b), " First object is positive inf, second is negative inf"
184
- else :
185
- assert np.isneginf(b), " First object is negative inf, second is positive inf"
182
+ fa, fb = a, b
183
+
184
+ # case for zero
185
+ if abs (fa) < 1e-5 :
186
+ if not decimal_almost_equal(fa, fb, decimal):
187
+ assert False , (
188
+ ' (very low values) expected %.5f but got %.5f , with decimal %d ' % (fb, fa, decimal)
189
+ )
186
190
else :
187
- fa, fb = a, b
188
-
189
- # case for zero
190
- if abs (fa) < 1e-5 :
191
- if not decimal_almost_equal(fa, fb, decimal):
192
- assert False , (
193
- ' (very low values) expected %.5f but got %.5f , with decimal %d ' % (fb, fa, decimal)
194
- )
195
- else :
196
- if not decimal_almost_equal(1 , fb / fa, decimal):
197
- assert False , ' expected %.5f but got %.5f , with decimal %d ' % (fb, fa, decimal)
198
-
199
- else :
200
- assert a == b, " %r != %r " % (a, b)
191
+ if not decimal_almost_equal(1 , fb / fa, decimal):
192
+ assert False , ' expected %.5f but got %.5f , with decimal %d ' % (fb, fa, decimal)
193
+ return True
201
194
202
- return True
195
+ raise AssertionError ( " {0} != {1} " .format(a, b))
0 commit comments