Skip to content

Commit 63f474d

Browse files
author
liwh
committed
BUG: Allow pairwise calculation when comparing the column with itself (#25781)
1 parent c021d33 commit 63f474d

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ Other enhancements
129129
- :meth:`DataFrame.__pos__`, :meth:`DataFrame.__neg__` now retain ``ExtensionDtype`` dtypes (:issue:`43883`)
130130
- The error raised when an optional dependency can't be imported now includes the original exception, for easier investigation (:issue:`43882`)
131131
- Added :meth:`.ExponentialMovingWindow.sum` (:issue:`13297`)
132+
- :meth:`DataFrame.corr` now accepts the argument ``calculate_diagonal`` to allow the results returned from a callable ``method`` to be used as the diagonal elements of the correlation matrix instead of setting them to ones (:issue:`25781`)
132133

133134
.. ---------------------------------------------------------------------------
134135

pandas/core/frame.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -9402,6 +9402,7 @@ def corr(
94029402
self,
94039403
method: str | Callable[[np.ndarray, np.ndarray], float] = "pearson",
94049404
min_periods: int = 1,
9405+
calculate_diagonal: bool = False,
94059406
) -> DataFrame:
94069407
"""
94079408
Compute pairwise correlation of columns, excluding NA/null values.
@@ -9422,6 +9423,12 @@ def corr(
94229423
Minimum number of observations required per pair of columns
94239424
to have a valid result. Currently only available for Pearson
94249425
and Spearman correlation.
9426+
calculate_diagonal : bool, optional
9427+
Whether to calculate the diagonal entries (a column's correlation with itself) using the supplied callable.
9428+
Ignored when method argument is not callable. If False, pairwise
9429+
correlation between a column and itself defaults to 1.
9430+
9431+
.. versionadded:: 1.4.0
94259432
94269433
Returns
94279434
-------
@@ -9471,7 +9478,7 @@ def corr(
94719478
valid = mask[i] & mask[j]
94729479
if valid.sum() < min_periods:
94739480
c = np.nan
9474-
elif i == j:
9481+
elif i == j and not calculate_diagonal:
94759482
c = 1.0
94769483
elif not valid.all():
94779484
c = corrf(ac[valid], bc[valid])

pandas/tests/frame/methods/test_cov_corr.py

+15
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,21 @@ def test_corr_min_periods_greater_than_length(self, method):
236236
)
237237
tm.assert_frame_equal(result, expected)
238238

239+
@pytest.mark.filterwarnings("ignore: An input array is constant")
240+
@td.skip_if_no_scipy
241+
@pytest.mark.parametrize("array_creator", [np.ones, np.zeros, np.random.random])
242+
def test_corr_diagonal_not_ones(self, array_creator):
243+
from scipy.stats import pearsonr
244+
245+
frame_size = 4
246+
df = DataFrame(array_creator((frame_size, frame_size)))
247+
cor_mat = df.corr(
248+
method=lambda x, y: pearsonr(x, y)[0], calculate_diagonal=True
249+
)
250+
result_diag = [cor_mat.loc[i, i] for i in range(frame_size)]
251+
expected_diag = [pearsonr(df[i], df[i])[0] for i in range(frame_size)]
252+
tm.assert_almost_equal(result_diag, expected_diag)
253+
239254

240255
class TestDataFrameCorrWith:
241256
def test_corrwith(self, datetime_frame):

0 commit comments

Comments
 (0)