diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index e13e6380905f2..d4eb797047c78 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -178,6 +178,7 @@ Other enhancements - :meth:`DataFrame.__pos__`, :meth:`DataFrame.__neg__` now retain ``ExtensionDtype`` dtypes (:issue:`43883`) - The error raised when an optional dependency can't be imported now includes the original exception, for easier investigation (:issue:`43882`) - Added :meth:`.ExponentialMovingWindow.sum` (:issue:`13297`) +- :meth:`DataFrame.corr` now accepts the argument ``calculate_diagonal`` to allow results returned from a callable to be used as diagonal elements of the correlation matrix instead of setting them to ones (:issue:`25781`) - :meth:`DataFrame.dropna` now accepts a single label as ``subset`` along with array-like (:issue:`41021`) - diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 5afb19f1d91fe..0c13c03ea5f22 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -9467,6 +9467,7 @@ def corr( self, method: str | Callable[[np.ndarray, np.ndarray], float] = "pearson", min_periods: int = 1, + calculate_diagonal: bool = False, ) -> DataFrame: """ Compute pairwise correlation of columns, excluding NA/null values. @@ -9487,6 +9488,12 @@ def corr( Minimum number of observations required per pair of columns to have a valid result. Currently only available for Pearson and Spearman correlation. + calculate_diagonal : bool, optional + Whether to calculate the diagonal elements using the supplied callable. + Ignored when the method argument is not callable. If False, the + correlation between a column and itself defaults to 1. + + .. 
versionadded:: 1.4.0 Returns ------- @@ -9536,7 +9543,7 @@ def corr( valid = mask[i] & mask[j] if valid.sum() < min_periods: c = np.nan - elif i == j: + elif i == j and not calculate_diagonal: c = 1.0 elif not valid.all(): c = corrf(ac[valid], bc[valid]) diff --git a/pandas/tests/frame/methods/test_cov_corr.py b/pandas/tests/frame/methods/test_cov_corr.py index 3dbf49df72558..0621c2bf0602e 100644 --- a/pandas/tests/frame/methods/test_cov_corr.py +++ b/pandas/tests/frame/methods/test_cov_corr.py @@ -236,6 +236,21 @@ def test_corr_min_periods_greater_than_length(self, method): ) tm.assert_frame_equal(result, expected) + @pytest.mark.filterwarnings("ignore: An input array is constant") + @td.skip_if_no_scipy + @pytest.mark.parametrize("array_creator", [np.ones, np.zeros, np.random.random]) + def test_corr_diagonal_not_ones(self, array_creator): + from scipy.stats import pearsonr + + frame_size = 4 + df = DataFrame(array_creator((frame_size, frame_size))) + cor_mat = df.corr( + method=lambda x, y: pearsonr(x, y)[0], calculate_diagonal=True + ) + result_diag = [cor_mat.loc[i, i] for i in range(frame_size)] + expected_diag = [pearsonr(df[i], df[i])[0] for i in range(frame_size)] + tm.assert_almost_equal(result_diag, expected_diag) + class TestDataFrameCorrWith: def test_corrwith(self, datetime_frame):