From ee32e334338980c77e5783978770d6834136d487 Mon Sep 17 00:00:00 2001 From: Shadi Akiki Date: Mon, 15 Jan 2024 08:08:46 -0500 Subject: [PATCH] ENH: scatter_matrix new parameter nondiagonal to allow hexbin plots instead of scatter --- pandas/plotting/_matplotlib/misc.py | 12 +++++++++--- pandas/plotting/_misc.py | 5 +++++ pandas/tests/plotting/test_misc.py | 11 +++++++++++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/pandas/plotting/_matplotlib/misc.py b/pandas/plotting/_matplotlib/misc.py index 1f9212587e05e..2af32afcb2329 100644 --- a/pandas/plotting/_matplotlib/misc.py +++ b/pandas/plotting/_matplotlib/misc.py @@ -38,6 +38,7 @@ def scatter_matrix( ax=None, grid: bool = False, diagonal: str = "hist", + nondiagonal: str = "scatter", marker: str = ".", density_kwds=None, hist_kwds=None, @@ -93,9 +94,14 @@ def scatter_matrix( else: common = (mask[a] & mask[b]).values - ax.scatter( - df[b][common], df[a][common], marker=marker, alpha=alpha, **kwds - ) + if nondiagonal == "scatter": + ax.scatter( + df[b][common], df[a][common], marker=marker, alpha=alpha, **kwds + ) + elif nondiagonal == "hexbin": + ax.hexbin( + df[b][common], df[a][common], alpha=alpha, **kwds + ) ax.set_xlim(boundaries_list[j]) ax.set_ylim(boundaries_list[i]) diff --git a/pandas/plotting/_misc.py b/pandas/plotting/_misc.py index 18db460d388a4..fe534d02cc9b2 100644 --- a/pandas/plotting/_misc.py +++ b/pandas/plotting/_misc.py @@ -159,6 +159,7 @@ def scatter_matrix( ax: Axes | None = None, grid: bool = False, diagonal: str = "hist", + nondiagonal: str = "scatter", marker: str = ".", density_kwds: Mapping[str, Any] | None = None, hist_kwds: Mapping[str, Any] | None = None, @@ -181,6 +182,9 @@ def scatter_matrix( diagonal : {'hist', 'kde'} Pick between 'kde' and 'hist' for either Kernel Density Estimation or Histogram plot in the diagonal. + nondiagonal : {'scatter', 'hexbin'} + Pick between 'scatter' and 'hexbin' for either scatter or + 2D hexagonal binning plot in the diagonal. marker : str, optional Matplotlib marker type, default '.'. density_kwds : keywords @@ -224,6 +228,7 @@ def scatter_matrix( ax=ax, grid=grid, diagonal=diagonal, + nondiagonal=nondiagonal, marker=marker, density_kwds=density_kwds, hist_kwds=hist_kwds, diff --git a/pandas/tests/plotting/test_misc.py b/pandas/tests/plotting/test_misc.py index cfb657c2a800f..ca031408b2a69 100644 --- a/pandas/tests/plotting/test_misc.py +++ b/pandas/tests/plotting/test_misc.py @@ -196,6 +196,17 @@ def test_scatter_matrix_axis_smaller(self, pass_axis): _check_text_labels(axes0_labels, expected) _check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0) + @pytest.mark.parametrize("nondiagonal", ["scatter", "hexbin"]) + def test_scatter_matrix_nondiagonal(self, nondiagonal): + pytest.importorskip("scipy") + scatter_matrix = plotting.scatter_matrix + df = DataFrame(np.random.default_rng(2).standard_normal((100, 3))) + _check_plot_works( + scatter_matrix, + frame=df, + nondiagonal=nondiagonal, + ) + @pytest.mark.slow def test_andrews_curves_no_warning(self, iris): from pandas.plotting import andrews_curves