Skip to content

Commit a393675

Browse files
shadiakiki1986 authored and jreback committed
ENH: correlation function accepts method being a callable (#22684)
1 parent 4a459b8 commit a393675

File tree

6 files changed

+85
-4
lines changed

6 files changed

+85
-4
lines changed

doc/source/computation.rst

+15
Original file line number | Diff line number | Diff line change
@@ -153,6 +153,21 @@ Like ``cov``, ``corr`` also supports the optional ``min_periods`` keyword:
153153
frame.corr(min_periods=12)
154154
155155
156+
.. versionadded:: 0.24.0
157+
158+
The ``method`` argument can also be a callable for a generic correlation
159+
calculation. In this case, it should be a single function
160+
that produces a single value from two ndarray inputs. Suppose we wanted to
161+
compute the correlation based on histogram intersection:
162+
163+
.. ipython:: python
164+
165+
# histogram intersection
166+
histogram_intersection = lambda a, b: np.minimum(
167+
np.true_divide(a, a.sum()), np.true_divide(b, b.sum())
168+
).sum()
169+
frame.corr(method=histogram_intersection)
170+
156171
A related method :meth:`~DataFrame.corrwith` is implemented on DataFrame to
157172
compute the correlation between like-labeled Series contained in different
158173
DataFrame objects.

doc/source/whatsnew/v0.24.0.txt

+2
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,8 @@ New features
2020
- :func:`DataFrame.to_parquet` now accepts ``index`` as an argument, allowing
2121
the user to override the engine's default behavior to include or omit the
2222
dataframe's indexes from the resulting Parquet file. (:issue:`20768`)
23+
- :meth:`DataFrame.corr` and :meth:`Series.corr` now accept a callable for generic calculation methods of correlation, e.g. histogram intersection (:issue:`22684`)
24+
2325

2426
.. _whatsnew_0240.enhancements.extension_array_operators:
2527

pandas/core/frame.py

+18-2
Original file line number | Diff line number | Diff line change
@@ -6672,10 +6672,14 @@ def corr(self, method='pearson', min_periods=1):
66726672
66736673
Parameters
66746674
----------
6675-
method : {'pearson', 'kendall', 'spearman'}
6675+
method : {'pearson', 'kendall', 'spearman'} or callable
66766676
* pearson : standard correlation coefficient
66776677
* kendall : Kendall Tau correlation coefficient
66786678
* spearman : Spearman rank correlation
6679+
* callable: callable with input two 1d ndarrays
6680+
and returning a float
6681+
.. versionadded:: 0.24.0
6682+
66796683
min_periods : int, optional
66806684
Minimum number of observations required per pair of columns
66816685
to have a valid result. Currently only available for pearson
@@ -6684,6 +6688,18 @@ def corr(self, method='pearson', min_periods=1):
66846688
Returns
66856689
-------
66866690
y : DataFrame
6691+
6692+
Examples
6693+
--------
6694+
>>> import numpy as np
6695+
>>> histogram_intersection = lambda a, b: np.minimum(a, b
6696+
... ).sum().round(decimals=1)
6697+
>>> df = pd.DataFrame([(.2, .3), (.0, .6), (.6, .0), (.2, .1)],
6698+
... columns=['dogs', 'cats'])
6699+
>>> df.corr(method=histogram_intersection)
6700+
dogs cats
6701+
dogs 1.0 0.3
6702+
cats 0.3 1.0
66876703
"""
66886704
numeric_df = self._get_numeric_data()
66896705
cols = numeric_df.columns
@@ -6695,7 +6711,7 @@ def corr(self, method='pearson', min_periods=1):
66956711
elif method == 'spearman':
66966712
correl = libalgos.nancorr_spearman(ensure_float64(mat),
66976713
minp=min_periods)
6698-
elif method == 'kendall':
6714+
elif method == 'kendall' or callable(method):
66996715
if min_periods is None:
67006716
min_periods = 1
67016717
mat = ensure_float64(mat).T

pandas/core/nanops.py

+2
Original file line number | Diff line number | Diff line change
@@ -766,6 +766,8 @@ def nancorr(a, b, method='pearson', min_periods=None):
766766
def get_corr_func(method):
767767
if method in ['kendall', 'spearman']:
768768
from scipy.stats import kendalltau, spearmanr
769+
elif callable(method):
770+
return method
769771

770772
def _pearson(a, b):
771773
return np.corrcoef(a, b)[0, 1]

pandas/core/series.py

+16-2
Original file line number | Diff line number | Diff line change
@@ -1910,23 +1910,37 @@ def corr(self, other, method='pearson', min_periods=None):
19101910
Parameters
19111911
----------
19121912
other : Series
1913-
method : {'pearson', 'kendall', 'spearman'}
1913+
method : {'pearson', 'kendall', 'spearman'} or callable
19141914
* pearson : standard correlation coefficient
19151915
* kendall : Kendall Tau correlation coefficient
19161916
* spearman : Spearman rank correlation
1917+
* callable: callable with input two 1d ndarrays
1918+
and returning a float
1919+
.. versionadded:: 0.24.0
1920+
19171921
min_periods : int, optional
19181922
Minimum number of observations needed to have a valid result
19191923
19201924
19211925
Returns
19221926
-------
19231927
correlation : float
1928+
1929+
Examples
1930+
--------
1931+
>>> import numpy as np
1932+
>>> histogram_intersection = lambda a, b: np.minimum(a, b
1933+
... ).sum().round(decimals=1)
1934+
>>> s1 = pd.Series([.2, .0, .6, .2])
1935+
>>> s2 = pd.Series([.3, .6, .0, .1])
1936+
>>> s1.corr(s2, method=histogram_intersection)
1937+
0.3
19241938
"""
19251939
this, other = self.align(other, join='inner', copy=False)
19261940
if len(this) == 0:
19271941
return np.nan
19281942

1929-
if method in ['pearson', 'spearman', 'kendall']:
1943+
if method in ['pearson', 'spearman', 'kendall'] or callable(method):
19301944
return nanops.nancorr(this.values, other.values, method=method,
19311945
min_periods=min_periods)
19321946

pandas/tests/series/test_analytics.py

+32
Original file line number | Diff line number | Diff line change
@@ -789,6 +789,38 @@ def test_corr_invalid_method(self):
789789
with tm.assert_raises_regex(ValueError, msg):
790790
s1.corr(s2, method="____")
791791

792+
def test_corr_callable_method(self):
793+
# simple correlation example
794+
# returns 1 if exact equality, 0 otherwise
795+
my_corr = lambda a, b: 1. if (a == b).all() else 0.
796+
797+
# simple example
798+
s1 = Series([1, 2, 3, 4, 5])
799+
s2 = Series([5, 4, 3, 2, 1])
800+
expected = 0
801+
tm.assert_almost_equal(
802+
s1.corr(s2, method=my_corr),
803+
expected)
804+
805+
# full overlap
806+
tm.assert_almost_equal(
807+
self.ts.corr(self.ts, method=my_corr), 1.)
808+
809+
# partial overlap
810+
tm.assert_almost_equal(
811+
self.ts[:15].corr(self.ts[5:], method=my_corr), 1.)
812+
813+
# No overlap
814+
assert np.isnan(
815+
self.ts[::2].corr(self.ts[1::2], method=my_corr))
816+
817+
# dataframe example
818+
df = pd.DataFrame([s1, s2])
819+
expected = pd.DataFrame([
820+
{0: 1., 1: 0}, {0: 0, 1: 1.}])
821+
tm.assert_almost_equal(
822+
df.transpose().corr(method=my_corr), expected)
823+
792824
def test_cov(self):
793825
# full overlap
794826
tm.assert_almost_equal(self.ts.cov(self.ts), self.ts.std() ** 2)

0 commit comments

Comments
 (0)