Skip to content

Commit 3b50b52

Browse files
committed
TST/PERF: Re-write assert_almost_equal() in cython #4398
Add a testing.pyx cython file, and port assert_almost_equal() from python to cython. On my machine this brings a modest gain to the suite of "not slow" tests (160s -> 140s), but on assert_almost_equal() heavy tests, like test_expressions.py, it shows a large improvement (14s -> 4s).
1 parent 7e3585d commit 3b50b52

File tree

4 files changed

+103
-67
lines changed

4 files changed

+103
-67
lines changed

doc/source/release.rst

+1
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ See :ref:`Internal Refactoring<whatsnew_0130.refactoring>`
414414
compatible. (:issue:`5213`, :issue:`5214`)
415415
- Unity ``dropna`` for Series/DataFrame signature (:issue:`5250`),
416416
tests from :issue:`5234`, courtesy of @rockg
417+
- Rewrite assert_almost_equal() in cython for performance (:issue:`4398`)
417418

418419
.. _release.bug_fixes-0.13.0:
419420

pandas/src/testing.pyx

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import numpy as np
2+
3+
from pandas import compat
4+
from pandas.core.common import isnull
5+
6+
cdef bint isiterable(obj):
7+
return hasattr(obj, '__iter__')
8+
9+
cdef bint decimal_almost_equal(double desired, double actual, int decimal):
10+
# Code from
11+
# http://docs.scipy.org/doc/numpy/reference/generated
12+
# /numpy.testing.assert_almost_equal.html
13+
return abs(desired - actual) < (0.5 * 10.0 ** -decimal)
14+
15+
cpdef assert_dict_equal(a, b, bint compare_keys=True):
16+
a_keys = frozenset(a.keys())
17+
b_keys = frozenset(b.keys())
18+
19+
if compare_keys:
20+
assert a_keys == b_keys
21+
22+
for k in a_keys:
23+
assert_almost_equal(a[k], b[k])
24+
25+
return True
26+
27+
cpdef assert_almost_equal(a, b, bint check_less_precise=False):
28+
cdef:
29+
int decimal
30+
Py_ssize_t i, na, nb
31+
double fa, fb
32+
33+
if isinstance(a, dict) or isinstance(b, dict):
34+
return assert_dict_equal(a, b)
35+
36+
if isinstance(a, compat.string_types):
37+
assert a == b, "%r != %r" % (a, b)
38+
return True
39+
40+
if isiterable(a):
41+
assert isiterable(b), "First object is iterable, second isn't"
42+
na, nb = len(a), len(b)
43+
assert na == nb, "%s != %s" % (na, nb)
44+
if (isinstance(a, np.ndarray) and
45+
isinstance(b, np.ndarray) and
46+
np.array_equal(a, b)):
47+
return True
48+
else:
49+
for i in xrange(na):
50+
assert_almost_equal(a[i], b[i], check_less_precise)
51+
return True
52+
53+
if isnull(a):
54+
assert isnull(b), "First object is null, second isn't"
55+
return True
56+
57+
if isinstance(a, (bool, float, int, np.float32)):
58+
decimal = 5
59+
60+
# deal with differing dtypes
61+
if check_less_precise:
62+
dtype_a = np.dtype(type(a))
63+
dtype_b = np.dtype(type(b))
64+
if dtype_a.kind == 'f' and dtype_b == 'f':
65+
if dtype_a.itemsize <= 4 and dtype_b.itemsize <= 4:
66+
decimal = 3
67+
68+
if np.isinf(a):
69+
assert np.isinf(b), "First object is inf, second isn't"
70+
else:
71+
fa, fb = a, b
72+
73+
# case for zero
74+
if abs(fa) < 1e-5:
75+
if not decimal_almost_equal(fa, fb, decimal):
76+
assert False, (
77+
'(very low values) expected %.5f but got %.5f' % (b, a)
78+
)
79+
else:
80+
if not decimal_almost_equal(1, fb / fa, decimal):
81+
assert False, 'expected %.5f but got %.5f' % (b, a)
82+
83+
else:
84+
assert a == b, "%s != %s" % (a, b)
85+
86+
return True

pandas/util/testing.py

+7-66
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
from pandas.tseries.index import DatetimeIndex
3838
from pandas.tseries.period import PeriodIndex
3939

40+
from pandas import _testing
41+
4042
from pandas.io.common import urlopen
4143

4244
Index = index.Index
@@ -50,6 +52,11 @@
5052
K = 4
5153
_RAISE_NETWORK_ERROR_DEFAULT = False
5254

55+
# NOTE: don't pass an NDFrame or index to this function - may not handle it
56+
# well.
57+
assert_almost_equal = _testing.assert_almost_equal
58+
59+
assert_dict_equal = _testing.assert_dict_equal
5360

5461
def randbool(size=(), p=0.5):
5562
return rand(*size) <= p
@@ -374,75 +381,9 @@ def assert_attr_equal(attr, left, right):
374381
def isiterable(obj):
375382
return hasattr(obj, '__iter__')
376383

377-
378-
# NOTE: don't pass an NDFrame or index to this function - may not handle it
379-
# well.
380-
def assert_almost_equal(a, b, check_less_precise=False):
381-
if isinstance(a, dict) or isinstance(b, dict):
382-
return assert_dict_equal(a, b)
383-
384-
if isinstance(a, compat.string_types):
385-
assert a == b, "%r != %r" % (a, b)
386-
return True
387-
388-
if isiterable(a):
389-
np.testing.assert_(isiterable(b))
390-
na, nb = len(a), len(b)
391-
assert na == nb, "%s != %s" % (na, nb)
392-
if isinstance(a, np.ndarray) and isinstance(b, np.ndarray) and\
393-
np.array_equal(a, b):
394-
return True
395-
else:
396-
for i in range(na):
397-
assert_almost_equal(a[i], b[i], check_less_precise)
398-
return True
399-
400-
err_msg = lambda a, b: 'expected %.5f but got %.5f' % (b, a)
401-
402-
if isnull(a):
403-
np.testing.assert_(isnull(b))
404-
return
405-
406-
if isinstance(a, (bool, float, int, np.float32)):
407-
decimal = 5
408-
409-
# deal with differing dtypes
410-
if check_less_precise:
411-
dtype_a = np.dtype(type(a))
412-
dtype_b = np.dtype(type(b))
413-
if dtype_a.kind == 'f' and dtype_b == 'f':
414-
if dtype_a.itemsize <= 4 and dtype_b.itemsize <= 4:
415-
decimal = 3
416-
417-
if np.isinf(a):
418-
assert np.isinf(b), err_msg(a, b)
419-
420-
# case for zero
421-
elif abs(a) < 1e-5:
422-
np.testing.assert_almost_equal(
423-
a, b, decimal=decimal, err_msg=err_msg(a, b), verbose=False)
424-
else:
425-
np.testing.assert_almost_equal(
426-
1, a / b, decimal=decimal, err_msg=err_msg(a, b), verbose=False)
427-
else:
428-
assert a == b, "%s != %s" % (a, b)
429-
430-
431384
def is_sorted(seq):
432385
return assert_almost_equal(seq, np.sort(np.array(seq)))
433386

434-
435-
def assert_dict_equal(a, b, compare_keys=True):
436-
a_keys = frozenset(a.keys())
437-
b_keys = frozenset(b.keys())
438-
439-
if compare_keys:
440-
assert(a_keys == b_keys)
441-
442-
for k in a_keys:
443-
assert_almost_equal(a[k], b[k])
444-
445-
446387
def assert_series_equal(left, right, check_dtype=True,
447388
check_index_type=False,
448389
check_series_type=False,

setup.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,8 @@ class CheckSDist(sdist):
304304
'pandas/index.pyx',
305305
'pandas/algos.pyx',
306306
'pandas/parser.pyx',
307-
'pandas/src/sparse.pyx']
307+
'pandas/src/sparse.pyx',
308+
'pandas/src/testing.pyx']
308309

309310
def initialize_options(self):
310311
sdist.initialize_options(self)
@@ -464,6 +465,13 @@ def pxd(name):
464465

465466
extensions.extend([sparse_ext])
466467

468+
testing_ext = Extension('pandas._testing',
469+
sources=[srcpath('testing', suffix=suffix)],
470+
include_dirs=[],
471+
libraries=libraries)
472+
473+
extensions.extend([testing_ext])
474+
467475
#----------------------------------------------------------------------
468476
# msgpack stuff here
469477

0 commit comments

Comments
 (0)