Skip to content

Commit 10dcb68

Browse files
ENH: correlation function accepts method being a callable
- other than the listed strings for the `method` argument, accept a callable for generic correlation calculations - minor fix of = to == in requirements file
1 parent 8a1c8ad commit 10dcb68

File tree

6 files changed

+65
-4
lines changed

6 files changed

+65
-4
lines changed

doc/source/computation.rst

+15
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,21 @@ Like ``cov``, ``corr`` also supports the optional ``min_periods`` keyword:
153153
frame.corr(min_periods=12)
154154
155155
156+
.. versionadded:: 0.24.0
157+
158+
The ``method`` argument can also be a callable for a generic correlation
159+
calculation. In this case, it should be a single function
160+
that produces a single value from two ndarray inputs. Suppose we wanted to
161+
compute the correlation based on histogram intersection:
162+
163+
.. ipython:: python
164+
165+
# histogram intersection
166+
histogram_intersection = lambda a, b: np.minimum(
167+
np.true_divide(a, a.sum()), np.true_divide(b, b.sum())
168+
).sum()
169+
frame.corr(method=histogram_intersection)
170+
156171
A related method :meth:`~DataFrame.corrwith` is implemented on DataFrame to
157172
compute the correlation between like-labeled Series contained in different
158173
DataFrame objects.

doc/source/whatsnew/v0.24.0.txt

+3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ New features
1717

1818
- ``ExcelWriter`` now accepts ``mode`` as a keyword argument, enabling append to existing workbooks when using the ``openpyxl`` engine (:issue:`3441`)
1919

20+
- :meth:`DataFrame.corr` and :meth:`Series.corr` now accept a callable for generic calculation methods of correlation, e.g. histogram intersection (:issue:`22684`)
21+
22+
2023
.. _whatsnew_0240.enhancements.extension_array_operators:
2124

2225
``ExtensionArray`` operator support

pandas/core/frame.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -6652,10 +6652,14 @@ def corr(self, method='pearson', min_periods=1):
66526652
66536653
Parameters
66546654
----------
6655-
method : {'pearson', 'kendall', 'spearman'}
6655+
method : {'pearson', 'kendall', 'spearman'} or callable
66566656
* pearson : standard correlation coefficient
66576657
* kendall : Kendall Tau correlation coefficient
66586658
* spearman : Spearman rank correlation
6659+
* callable: callable with input two 1d ndarrays
6660+
and returning a float
6661+
.. versionadded:: 0.24.0
6662+
66596663
min_periods : int, optional
66606664
Minimum number of observations required per pair of columns
66616665
to have a valid result. Currently only available for pearson
@@ -6675,7 +6679,7 @@ def corr(self, method='pearson', min_periods=1):
66756679
elif method == 'spearman':
66766680
correl = libalgos.nancorr_spearman(ensure_float64(mat),
66776681
minp=min_periods)
6678-
elif method == 'kendall':
6682+
elif method == 'kendall' or callable(method):
66796683
if min_periods is None:
66806684
min_periods = 1
66816685
mat = ensure_float64(mat).T

pandas/core/nanops.py

+3
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,9 @@ def nancorr(a, b, method='pearson', min_periods=None):
764764

765765

766766
def get_corr_func(method):
767+
if callable(method):
768+
return method
769+
767770
if method in ['kendall', 'spearman']:
768771
from scipy.stats import kendalltau, spearmanr
769772

pandas/core/series.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1913,10 +1913,14 @@ def corr(self, other, method='pearson', min_periods=None):
19131913
Parameters
19141914
----------
19151915
other : Series
1916-
method : {'pearson', 'kendall', 'spearman'}
1916+
method : {'pearson', 'kendall', 'spearman'} or callable
19171917
* pearson : standard correlation coefficient
19181918
* kendall : Kendall Tau correlation coefficient
19191919
* spearman : Spearman rank correlation
1920+
* callable: callable with input two 1d ndarray
1921+
and returning a float
1922+
.. versionadded:: 0.24.0
1923+
19201924
min_periods : int, optional
19211925
Minimum number of observations needed to have a valid result
19221926
@@ -1929,7 +1933,7 @@ def corr(self, other, method='pearson', min_periods=None):
19291933
if len(this) == 0:
19301934
return np.nan
19311935

1932-
if method in ['pearson', 'spearman', 'kendall']:
1936+
if method in ['pearson', 'spearman', 'kendall'] or callable(method):
19331937
return nanops.nancorr(this.values, other.values, method=method,
19341938
min_periods=min_periods)
19351939

pandas/tests/series/test_analytics.py

+32
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,38 @@ def test_corr_invalid_method(self):
789789
with tm.assert_raises_regex(ValueError, msg):
790790
s1.corr(s2, method="____")
791791

792+
def test_corr_callable_method(self):
793+
# simple correlation example
794+
# returns 1 if exact equality, 0 otherwise
795+
my_corr = lambda a, b: 1. if (a==b).all() else 0.
796+
797+
# simple example
798+
s1 = Series([1, 2, 3, 4, 5])
799+
s2 = Series([5, 4, 3, 2, 1])
800+
expected_1 = 0
801+
tm.assert_almost_equal(
802+
s1.corr(s2, method=my_corr),
803+
expected_1)
804+
805+
# full overlap
806+
tm.assert_almost_equal(
807+
self.ts.corr(self.ts, method=my_corr), 1.)
808+
809+
# partial overlap
810+
tm.assert_almost_equal(
811+
self.ts[:15].corr(self.ts[5:], method=my_corr), 1.)
812+
813+
# No overlap
814+
assert np.isnan(
815+
self.ts[::2].corr(self.ts[1::2], method=my_corr))
816+
817+
# dataframe example
818+
df = pd.DataFrame([s1, s2])
819+
expected_2 = pd.DataFrame([
820+
{0: 1., 1: expected_1}, {0: expected_1, 1: 1.}])
821+
tm.assert_almost_equal(
822+
df.transpose().corr(method=my_corr), expected_2)
823+
792824
def test_cov(self):
793825
# full overlap
794826
tm.assert_almost_equal(self.ts.cov(self.ts), self.ts.std() ** 2)

0 commit comments

Comments
 (0)