Skip to content

Commit 7d561dd

Browse files
Daniel Saxton, Pingviinituutti
Daniel Saxton
authored and committed
ENH: Enable DataFrame.corrwith to compute rank correlations (pandas-dev#22375)
1 parent d70580f commit 7d561dd

File tree

4 files changed

+92
-18
lines changed

4 files changed

+92
-18
lines changed

asv_bench/benchmarks/stat_ops.py

+7
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def setup(self, method, use_bottleneck):
106106
from pandas.core import nanops
107107
nanops._USE_BOTTLENECK = use_bottleneck
108108
self.df = pd.DataFrame(np.random.randn(1000, 30))
109+
self.df2 = pd.DataFrame(np.random.randn(1000, 30))
109110
self.s = pd.Series(np.random.randn(1000))
110111
self.s2 = pd.Series(np.random.randn(1000))
111112

@@ -115,6 +116,12 @@ def time_corr(self, method, use_bottleneck):
115116
def time_corr_series(self, method, use_bottleneck):
116117
self.s.corr(self.s2, method=method)
117118

119+
def time_corrwith_cols(self, method, use_bottleneck):
120+
self.df.corrwith(self.df2, method=method)
121+
122+
def time_corrwith_rows(self, method, use_bottleneck):
123+
self.df.corrwith(self.df2, axis=1, method=method)
124+
118125

119126
class Covariance(object):
120127

doc/source/whatsnew/v0.24.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ Other Enhancements
414414
- :meth:`DataFrame.to_records` now accepts ``index_dtypes`` and ``column_dtypes`` parameters to allow different data types in stored column and index records (:issue:`18146`)
415415
- :class:`IntervalIndex` has gained the :attr:`~IntervalIndex.is_overlapping` attribute to indicate if the ``IntervalIndex`` contains any overlapping intervals (:issue:`23309`)
416416
- :func:`pandas.DataFrame.to_sql` has gained the ``method`` argument to control SQL insertion clause. See the :ref:`insertion method <io.sql.method>` section in the documentation. (:issue:`8953`)
417+
- :meth:`DataFrame.corrwith` now supports Spearman's rank correlation and Kendall's tau, as well as callable correlation methods. (:issue:`21925`)
417418

418419
.. _whatsnew_0240.api_breaking:
419420

pandas/core/frame.py

+58-18
Original file line numberDiff line numberDiff line change
@@ -6957,6 +6957,11 @@ def corr(self, method='pearson', min_periods=1):
69576957
dogs cats
69586958
dogs 1.0 0.3
69596959
cats 0.3 1.0
6960+
6961+
See Also
6962+
--------
6963+
DataFrame.corrwith
6964+
Series.corr
69606965
"""
69616966
numeric_df = self._get_numeric_data()
69626967
cols = numeric_df.columns
@@ -7110,54 +7115,89 @@ def cov(self, min_periods=None):
71107115

71117116
return self._constructor(baseCov, index=idx, columns=cols)
71127117

7113-
def corrwith(self, other, axis=0, drop=False):
7118+
def corrwith(self, other, axis=0, drop=False, method='pearson'):
71147119
"""
7115-
Compute pairwise correlation between rows or columns of two DataFrame
7116-
objects.
7120+
Compute pairwise correlation between rows or columns of DataFrame
7121+
with rows or columns of Series or DataFrame. DataFrames are first
7122+
aligned along both axes before computing the correlations.
71177123
71187124
Parameters
71197125
----------
71207126
other : DataFrame, Series
71217127
axis : {0 or 'index', 1 or 'columns'}, default 0
71227128
0 or 'index' to compute column-wise, 1 or 'columns' for row-wise
71237129
drop : boolean, default False
7124-
Drop missing indices from result, default returns union of all
7130+
Drop missing indices from result.
7131+
method : {'pearson', 'kendall', 'spearman'} or callable
7132+
* pearson : standard correlation coefficient
7133+
* kendall : Kendall Tau correlation coefficient
7134+
* spearman : Spearman rank correlation
7135+
* callable : callable taking two 1d ndarrays as input
7136+
and returning a float
7137+
7138+
.. versionadded:: 0.24.0
71257139
71267140
Returns
71277141
-------
71287142
correls : Series
7143+
7144+
See Also
7145+
--------
7146+
DataFrame.corr
71297147
"""
71307148
axis = self._get_axis_number(axis)
71317149
this = self._get_numeric_data()
71327150

71337151
if isinstance(other, Series):
7134-
return this.apply(other.corr, axis=axis)
7152+
return this.apply(lambda x: other.corr(x, method=method),
7153+
axis=axis)
71357154

71367155
other = other._get_numeric_data()
7137-
71387156
left, right = this.align(other, join='inner', copy=False)
71397157

7140-
# mask missing values
7141-
left = left + right * 0
7142-
right = right + left * 0
7143-
71447158
if axis == 1:
71457159
left = left.T
71467160
right = right.T
71477161

7148-
# demeaned data
7149-
ldem = left - left.mean()
7150-
rdem = right - right.mean()
7162+
if method == 'pearson':
7163+
# mask missing values
7164+
left = left + right * 0
7165+
right = right + left * 0
7166+
7167+
# demeaned data
7168+
ldem = left - left.mean()
7169+
rdem = right - right.mean()
71517170

7152-
num = (ldem * rdem).sum()
7153-
dom = (left.count() - 1) * left.std() * right.std()
7171+
num = (ldem * rdem).sum()
7172+
dom = (left.count() - 1) * left.std() * right.std()
71547173

7155-
correl = num / dom
7174+
correl = num / dom
7175+
7176+
elif method in ['kendall', 'spearman'] or callable(method):
7177+
def c(x):
7178+
return nanops.nancorr(x[0], x[1], method=method)
7179+
7180+
correl = Series(map(c,
7181+
zip(left.values.T, right.values.T)),
7182+
index=left.columns)
7183+
7184+
else:
7185+
raise ValueError("Invalid method {method} was passed, "
7186+
"valid methods are: 'pearson', 'kendall', "
7187+
"'spearman', or callable".
7188+
format(method=method))
71567189

71577190
if not drop:
7191+
# Find non-matching labels along the given axis
7192+
# and append missing correlations (GH 22375)
71587193
raxis = 1 if axis == 0 else 0
7159-
result_index = this._get_axis(raxis).union(other._get_axis(raxis))
7160-
correl = correl.reindex(result_index)
7194+
result_index = (this._get_axis(raxis).
7195+
union(other._get_axis(raxis)))
7196+
idx_diff = result_index.difference(correl.index)
7197+
7198+
if len(idx_diff) > 0:
7199+
correl = correl.append(Series([np.nan] * len(idx_diff),
7200+
index=idx_diff))
71617201

71627202
return correl
71637203

pandas/tests/frame/test_analytics.py

+26
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,32 @@ def test_corrwith_mixed_dtypes(self):
459459
expected = pd.Series(data=corrs, index=['a', 'b'])
460460
tm.assert_series_equal(result, expected)
461461

462+
def test_corrwith_dup_cols(self):
463+
# GH 21925
464+
df1 = pd.DataFrame(np.vstack([np.arange(10)] * 3).T)
465+
df2 = df1.copy()
466+
df2 = pd.concat((df2, df2[0]), axis=1)
467+
468+
result = df1.corrwith(df2)
469+
expected = pd.Series(np.ones(4), index=[0, 0, 1, 2])
470+
tm.assert_series_equal(result, expected)
471+
472+
@td.skip_if_no_scipy
473+
def test_corrwith_spearman(self):
474+
# GH 21925
475+
df = pd.DataFrame(np.random.random(size=(100, 3)))
476+
result = df.corrwith(df**2, method="spearman")
477+
expected = Series(np.ones(len(result)))
478+
tm.assert_series_equal(result, expected)
479+
480+
@td.skip_if_no_scipy
481+
def test_corrwith_kendall(self):
482+
# GH 21925
483+
df = pd.DataFrame(np.random.random(size=(100, 3)))
484+
result = df.corrwith(df**2, method="kendall")
485+
expected = Series(np.ones(len(result)))
486+
tm.assert_series_equal(result, expected)
487+
462488
def test_bool_describe_in_mixed_frame(self):
463489
df = DataFrame({
464490
'string_data': ['a', 'b', 'c', 'd', 'e'],

0 commit comments

Comments
 (0)