Skip to content

Commit 113511e

Browse files
committed
ENH: add kendall/spearman correlation methods, GH #428
1 parent 5d94e1b commit 113511e

File tree

8 files changed

+142
-53
lines changed

8 files changed

+142
-53
lines changed

RELEASE.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,16 @@ pandas 0.6.1
3131

3232
- Can pass Series to DataFrame.append with ignore_index=True for appending a
3333
single row (GH #430)
34+
- Add Spearman and Kendall correlation options to Series.corr and
35+
DataFrame.corr (GH #428)
3436

3537
**Improvements to existing features**
3638
- Improve memory usage of `DataFrame.describe` (do not copy data
3739
unnecessarily) (PR #425)
3840
- Use same formatting function for outputting floating point Series to console
3941
as in DataFrame (PR #420)
4042
- DataFrame.delevel will try to infer better dtype for new columns (GH #440)
41-
- Exclude non-numeric types in DataFrame.corr
43+
- Exclude non-numeric types in DataFrame.{corr, cov}
4244

4345
**Bug fixes**
4446

@@ -52,7 +54,6 @@ pandas 0.6.1
5254
- Fix groupby exception raised with as_index=False and single column selected
5355
(GH #421)
5456

55-
5657
Thanks
5758
------
5859
- Ralph Bean

pandas/core/frame.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2563,26 +2563,36 @@ def _join_index(self, other, how, lsuffix, rsuffix):
25632563
#----------------------------------------------------------------------
25642564
# Statistical methods, etc.
25652565

2566-
def corr(self):
2566+
def corr(self, method='pearson'):
25672567
"""
25682568
Compute pairwise correlation of columns, excluding NA/null values
25692569
2570+
Parameters
2571+
----------
2572+
method : {'pearson', 'kendall', 'spearman'}
2573+
pearson : standard correlation coefficient
2574+
kendall : Kendall Tau correlation coefficient
2575+
spearman : Spearman rank correlation
2576+
25702577
Returns
25712578
-------
25722579
y : DataFrame
25732580
"""
25742581
cols = self._get_numeric_columns()
25752582
mat = self.as_matrix(cols).T
2576-
baseCov = np.cov(mat)
2577-
2578-
sigma = np.sqrt(np.diag(baseCov))
2579-
correl = baseCov / np.outer(sigma, sigma)
2580-
2581-
# Get the covariance with items that have NaN values
2582-
for i, j, ac, bc in self._cov_helper(mat):
2583-
c = np.corrcoef(ac, bc)[0, 1]
2584-
correl[i, j] = c
2585-
correl[j, i] = c
2583+
corrf = nanops.get_corr_func(method)
2584+
K = len(cols)
2585+
correl = np.empty((K, K), dtype=float)
2586+
mask = np.isfinite(mat)
2587+
for i, ac in enumerate(mat):
2588+
for j, bc in enumerate(mat):
2589+
valid = mask[i] & mask[j]
2590+
if not valid.all():
2591+
c = corrf(ac[valid], bc[valid])
2592+
else:
2593+
c = corrf(ac, bc)
2594+
correl[i, j] = c
2595+
correl[j, i] = c
25862596

25872597
return self._constructor(correl, index=cols, columns=cols)
25882598

@@ -2594,7 +2604,7 @@ def cov(self):
25942604
-------
25952605
y : DataFrame
25962606
"""
2597-
cols = self.columns
2607+
cols = self._get_numeric_columns()
25982608
mat = self.as_matrix(cols).T
25992609
baseCov = np.cov(mat)
26002610

pandas/core/nanops.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,48 @@ def _zero_out_fperr(arg):
218218
return np.where(np.abs(arg) < 1e-14, 0, arg)
219219
else:
220220
return 0 if np.abs(arg) < 1e-14 else arg
221+
222+
def nancorr(a, b, method='pearson'):
223+
"""
224+
a, b: ndarrays
225+
"""
226+
assert(len(a) == len(b))
227+
if len(a) == 0:
228+
return np.nan
229+
230+
valid = notnull(a) & notnull(b)
231+
if not valid.all():
232+
a = a[valid]
233+
b = b[valid]
234+
235+
f = get_corr_func(method)
236+
return f(a, b)
237+
238+
def get_corr_func(method):
239+
if method in ['kendall', 'spearman']:
240+
from scipy.stats import kendalltau, spearmanr
241+
242+
def _pearson(a, b):
243+
return np.corrcoef(a, b)[0, 1]
244+
def _kendall(a, b):
245+
return kendalltau(a, b)[0]
246+
def _spearman(a, b):
247+
return spearmanr(a, b)[0]
248+
249+
_cor_methods = {
250+
'pearson' : _pearson,
251+
'kendall' : _kendall,
252+
'spearman' : _spearman
253+
}
254+
return _cor_methods[method]
255+
256+
def nancov(a, b):
257+
assert(len(a) == len(b))
258+
if len(a) == 0:
259+
return np.nan
260+
261+
valid = notnull(a) & notnull(b)
262+
if not valid.all():
263+
a = a[valid]
264+
b = b[valid]
265+
return np.cov(a, b)[0, 1]

pandas/core/series.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -867,22 +867,24 @@ def describe(self):
867867

868868
return Series(data, index=names)
869869

870-
def corr(self, other):
870+
def corr(self, other, method='pearson'):
871871
"""
872872
Compute correlation two Series, excluding missing values
873873
874874
Parameters
875875
----------
876876
other : Series
877+
method : {'pearson', 'kendall', 'spearman'}
878+
pearson : standard correlation coefficient
879+
kendall : Kendall Tau correlation coefficient
880+
spearman : Spearman rank correlation
877881
878882
Returns
879883
-------
880884
correlation : float
881885
"""
882-
this, that = self._get_nonna_aligned(other)
883-
if this is None or that is None:
884-
return nan
885-
return np.corrcoef(this, that)[0, 1]
886+
this, other = self.align(other, join='inner')
887+
return nanops.nancorr(this.values, other.values, method=method)
886888

887889
def cov(self, other):
888890
"""
@@ -896,23 +898,10 @@ def cov(self, other):
896898
-------
897899
covariance : float
898900
"""
899-
this, that = self._get_nonna_aligned(other)
900-
if this is None or that is None:
901-
return nan
902-
return np.cov(this, that)[0, 1]
903-
904-
def _get_nonna_aligned(self, other):
905-
"""
906-
Returns two sub-Series with the same index and only non-na values
907-
"""
908-
commonIdx = self.dropna().index.intersection(other.dropna().index)
909-
910-
if len(commonIdx) == 0:
911-
return None, None
912-
913-
this = self.reindex(commonIdx)
914-
that = other.reindex(commonIdx)
915-
return this, that
901+
this, other = self.align(other, join='inner')
902+
if len(this) == 0:
903+
return np.nan
904+
return nanops.nancov(this.values, other.values)
916905

917906
def diff(self, periods=1):
918907
"""

pandas/tests/test_frame.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2075,10 +2075,14 @@ def test_corr(self):
20752075
self.frame['A'][:5] = nan
20762076
self.frame['B'][:10] = nan
20772077

2078-
correls = self.frame.corr()
2078+
def _check_method(method='pearson'):
2079+
correls = self.frame.corr(method=method)
2080+
exp = self.frame['A'].corr(self.frame['C'], method=method)
2081+
assert_almost_equal(correls['A']['C'], exp)
20792082

2080-
assert_almost_equal(correls['A']['C'],
2081-
self.frame['A'].corr(self.frame['C']))
2083+
_check_method('pearson')
2084+
_check_method('kendall')
2085+
_check_method('spearman')
20822086

20832087
# exclude non-numeric types
20842088
result = self.mixed_frame.corr()
@@ -2093,6 +2097,11 @@ def test_cov(self):
20932097
assert_almost_equal(cov['A']['C'],
20942098
self.frame['A'].cov(self.frame['C']))
20952099

2100+
# exclude non-numeric types
2101+
result = self.mixed_frame.cov()
2102+
expected = self.mixed_frame.ix[:, ['A', 'B', 'C', 'D']].cov()
2103+
assert_frame_equal(result, expected)
2104+
20962105
def test_corrwith(self):
20972106
a = self.tsframe
20982107
noise = Series(randn(len(a)), index=a.index)

pandas/tests/test_multilevel.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,7 @@
1313
assert_frame_equal)
1414
import pandas.core.common as com
1515
import pandas.util.testing as tm
16-
17-
try:
18-
from itertools import product as cart_product
19-
except ImportError: # python 2.5
20-
def cart_product(*args, **kwds):
21-
# product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
22-
# product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
23-
pools = map(tuple, args) * kwds.get('repeat', 1)
24-
result = [[]]
25-
for pool in pools:
26-
result = [x+[y] for x in result for y in pool]
27-
for prod in result:
28-
yield tuple(prod)
16+
from pandas.util.compat import product as cart_product
2917

3018
class TestMultiLevel(unittest.TestCase):
3119

pandas/tests/test_series.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,8 @@ def test_combine_first(self):
879879
assert_series_equal(s, result)
880880

881881
def test_corr(self):
882+
import scipy.stats as stats
883+
882884
# full overlap
883885
self.assertAlmostEqual(self.ts.corr(self.ts), 1)
884886

@@ -888,7 +890,38 @@ def test_corr(self):
888890
# No overlap
889891
self.assert_(np.isnan(self.ts[::2].corr(self.ts[1::2])))
890892

891-
# additional checks?
893+
A = tm.makeTimeSeries()
894+
B = tm.makeTimeSeries()
895+
result = A.corr(B)
896+
expected, _ = stats.pearsonr(A, B)
897+
self.assertAlmostEqual(result, expected)
898+
899+
def test_corr_rank(self):
900+
import scipy.stats as stats
901+
# kendall and spearman
902+
903+
A = tm.makeTimeSeries()
904+
B = tm.makeTimeSeries()
905+
A[-5:] = A[:5]
906+
result = A.corr(B, method='kendall')
907+
expected = stats.kendalltau(A, B)[0]
908+
self.assertAlmostEqual(result, expected)
909+
910+
result = A.corr(B, method='spearman')
911+
expected = stats.spearmanr(A, B)[0]
912+
self.assertAlmostEqual(result, expected)
913+
914+
# results from R
915+
A = Series([-0.89926396, 0.94209606, -1.03289164, -0.95445587,
916+
0.76910310, -0.06430576, -2.09704447, 0.40660407,
917+
-0.89926396, 0.94209606])
918+
B = Series([-1.01270225, -0.62210117, -1.56895827, 0.59592943,
919+
-0.01680292, 1.17258718, -1.06009347, -0.10222060,
920+
-0.89076239, 0.89372375])
921+
kexp = 0.4319297
922+
sexp = 0.5853767
923+
self.assertAlmostEqual(A.corr(B, method='kendall'), kexp)
924+
self.assertAlmostEqual(A.corr(B, method='spearman'), sexp)
892925

893926
def test_cov(self):
894927
# full overlap

pandas/util/compat.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# itertools.product not in Python 2.5
2+
3+
try:
4+
from itertools import product
5+
except ImportError: # python 2.5
6+
def product(*args, **kwds):
7+
# product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
8+
# product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
9+
pools = map(tuple, args) * kwds.get('repeat', 1)
10+
result = [[]]
11+
for pool in pools:
12+
result = [x+[y] for x in result for y in pool]
13+
for prod in result:
14+
yield tuple(prod)

0 commit comments

Comments
 (0)