Skip to content

Commit 573caff

Browse files
authored
BUG: scatter_matrix raising with 2d axis passed (#38668)
1 parent 5fecf47 commit 573caff

File tree

3 files changed

+21
-6
lines changed

3 files changed

+21
-6
lines changed

doc/source/whatsnew/v1.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ Period
262262
Plotting
263263
^^^^^^^^
264264

265+
- Bug in :func:`scatter_matrix` raising when 2d ``ax`` argument passed (:issue:`16253`)
265266
-
266267
-
267268

pandas/plotting/_matplotlib/tools.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ def create_subplots(
196196
fig = plt.figure(**fig_kw)
197197
else:
198198
if is_list_like(ax):
199-
ax = flatten_axes(ax)
199+
if squeeze:
200+
ax = flatten_axes(ax)
200201
if layout is not None:
201202
warnings.warn(
202203
"When passing multiple axes, layout keyword is ignored", UserWarning
@@ -208,8 +209,8 @@ def create_subplots(
208209
UserWarning,
209210
stacklevel=4,
210211
)
211-
if len(ax) == naxes:
212-
fig = ax[0].get_figure()
212+
if ax.size == naxes:
213+
fig = ax.flat[0].get_figure()
213214
return fig, ax
214215
else:
215216
raise ValueError(

pandas/tests/plotting/test_misc.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,16 @@ def test_bootstrap_plot(self):
9494
@td.skip_if_no_mpl
9595
class TestDataFramePlots(TestPlotBase):
9696
@td.skip_if_no_scipy
97-
def test_scatter_matrix_axis(self):
97+
@pytest.mark.parametrize("pass_axis", [False, True])
98+
def test_scatter_matrix_axis(self, pass_axis):
9899
from pandas.plotting._matplotlib.compat import mpl_ge_3_0_0
99100

100101
scatter_matrix = plotting.scatter_matrix
101102

103+
ax = None
104+
if pass_axis:
105+
_, ax = self.plt.subplots(3, 3)
106+
102107
with tm.RNGContext(42):
103108
df = DataFrame(np.random.randn(100, 3))
104109

@@ -107,7 +112,11 @@ def test_scatter_matrix_axis(self):
107112
UserWarning, raise_on_extra_warnings=mpl_ge_3_0_0()
108113
):
109114
axes = _check_plot_works(
110-
scatter_matrix, filterwarnings="always", frame=df, range_padding=0.1
115+
scatter_matrix,
116+
filterwarnings="always",
117+
frame=df,
118+
range_padding=0.1,
119+
ax=ax,
111120
)
112121
axes0_labels = axes[0][0].yaxis.get_majorticklabels()
113122

@@ -121,7 +130,11 @@ def test_scatter_matrix_axis(self):
121130
# we are plotting multiples on a sub-plot
122131
with tm.assert_produces_warning(UserWarning):
123132
axes = _check_plot_works(
124-
scatter_matrix, filterwarnings="always", frame=df, range_padding=0.1
133+
scatter_matrix,
134+
filterwarnings="always",
135+
frame=df,
136+
range_padding=0.1,
137+
ax=ax,
125138
)
126139
axes0_labels = axes[0][0].yaxis.get_majorticklabels()
127140
expected = ["-1.0", "-0.5", "0.0"]

0 commit comments

Comments
 (0)