Skip to content

Commit bacebce

Browse files
committed
BUG: consistent subplot ax handling
1 parent 8cd3dd6 commit bacebce

File tree

3 files changed

+57
-39
lines changed

3 files changed

+57
-39
lines changed

doc/source/v0.14.1.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ Bug Fixes
176176
- Bug in ``to_timedelta`` that accepted invalid units and misinterpreted 'm/h' (:issue:`7611`, :issue: `6423`)
177177

178178
- Bug in grouped ``hist`` and ``scatter`` plots use old ``figsize`` default (:issue:`7394`)
179+
- Bug in plotting subplots with ``DataFrame.plot``, ``hist`` clears passed ``ax`` even if the number of subplots is one (:issue:`7391`).
180+
- Bug in plotting subplots with ``DataFrame.boxplot`` with ``by`` kw raises ``ValueError`` if the number of subplots exceeds 1 (:issue:`7391`).
179181

180182
- Bug in ``Panel.apply`` with a multi-index as an axis (:issue:`7469`)
181183

pandas/tests/test_graphics.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,13 @@ def test_plot(self):
859859
axes = _check_plot_works(df.plot, kind='bar', subplots=True)
860860
self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
861861

862+
# When ax is supplied and required number of axes is 1,
863+
# passed ax should be used:
864+
fig, ax = self.plt.subplots()
865+
axes = df.plot(kind='bar', subplots=True, ax=ax)
866+
self.assertEqual(len(axes), 1)
867+
self.assertIs(ax.get_axes(), axes[0])
868+
862869
def test_nonnumeric_exclude(self):
863870
df = DataFrame({'A': ["x", "y", "z"], 'B': [1, 2, 3]})
864871
ax = df.plot()
@@ -1419,17 +1426,23 @@ def test_boxplot(self):
14191426

14201427
df = DataFrame(np.random.rand(10, 2), columns=['Col1', 'Col2'])
14211428
df['X'] = Series(['A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B'])
1429+
df['Y'] = Series(['A'] * 10)
14221430
_check_plot_works(df.boxplot, by='X')
14231431

1424-
# When ax is supplied, existing axes should be used:
1432+
# When ax is supplied and required number of axes is 1,
1433+
# passed ax should be used:
14251434
fig, ax = self.plt.subplots()
14261435
axes = df.boxplot('Col1', by='X', ax=ax)
14271436
self.assertIs(ax.get_axes(), axes)
14281437

1429-
# Multiple columns with an ax argument is not supported
14301438
fig, ax = self.plt.subplots()
1431-
with tm.assertRaisesRegexp(ValueError, 'existing axis'):
1432-
df.boxplot(column=['Col1', 'Col2'], by='X', ax=ax)
1439+
axes = df.groupby('Y').boxplot(ax=ax, return_type='axes')
1440+
self.assertIs(ax.get_axes(), axes['A'])
1441+
1442+
# Multiple columns with an ax argument should use same figure
1443+
fig, ax = self.plt.subplots()
1444+
axes = df.boxplot(column=['Col1', 'Col2'], by='X', ax=ax, return_type='axes')
1445+
self.assertIs(axes['Col1'].get_figure(), fig)
14331446

14341447
# When by is None, check that all relevant lines are present in the dict
14351448
fig, ax = self.plt.subplots()
@@ -2180,32 +2193,32 @@ class TestDataFrameGroupByPlots(TestPlotBase):
21802193
@slow
21812194
def test_boxplot(self):
21822195
grouped = self.hist_df.groupby(by='gender')
2183-
box = _check_plot_works(grouped.boxplot, return_type='dict')
2184-
self._check_axes_shape(self.plt.gcf().axes, axes_num=2, layout=(1, 2))
2196+
axes = _check_plot_works(grouped.boxplot, return_type='axes')
2197+
self._check_axes_shape(axes.values(), axes_num=2, layout=(1, 2))
21852198

2186-
box = _check_plot_works(grouped.boxplot, subplots=False,
2187-
return_type='dict')
2188-
self._check_axes_shape(self.plt.gcf().axes, axes_num=2, layout=(1, 2))
2199+
axes = _check_plot_works(grouped.boxplot, subplots=False,
2200+
return_type='axes')
2201+
self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
21892202

21902203
tuples = lzip(string.ascii_letters[:10], range(10))
21912204
df = DataFrame(np.random.rand(10, 3),
21922205
index=MultiIndex.from_tuples(tuples))
21932206

21942207
grouped = df.groupby(level=1)
2195-
box = _check_plot_works(grouped.boxplot, return_type='dict')
2196-
self._check_axes_shape(self.plt.gcf().axes, axes_num=10, layout=(4, 3))
2208+
axes = _check_plot_works(grouped.boxplot, return_type='axes')
2209+
self._check_axes_shape(axes.values(), axes_num=10, layout=(4, 3))
21972210

2198-
box = _check_plot_works(grouped.boxplot, subplots=False,
2199-
return_type='dict')
2200-
self._check_axes_shape(self.plt.gcf().axes, axes_num=10, layout=(4, 3))
2211+
axes = _check_plot_works(grouped.boxplot, subplots=False,
2212+
return_type='axes')
2213+
self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
22012214

22022215
grouped = df.unstack(level=1).groupby(level=0, axis=1)
2203-
box = _check_plot_works(grouped.boxplot, return_type='dict')
2204-
self._check_axes_shape(self.plt.gcf().axes, axes_num=3, layout=(2, 2))
2216+
axes = _check_plot_works(grouped.boxplot, return_type='axes')
2217+
self._check_axes_shape(axes.values(), axes_num=3, layout=(2, 2))
22052218

2206-
box = _check_plot_works(grouped.boxplot, subplots=False,
2207-
return_type='dict')
2208-
self._check_axes_shape(self.plt.gcf().axes, axes_num=3, layout=(2, 2))
2219+
axes = _check_plot_works(grouped.boxplot, subplots=False,
2220+
return_type='axes')
2221+
self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
22092222

22102223
def test_series_plot_color_kwargs(self):
22112224
# GH1890

pandas/tools/plotting.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2665,7 +2665,8 @@ def plot_group(group, ax):
26652665

26662666

26672667
def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
2668-
rot=0, grid=True, figsize=None, layout=None, **kwds):
2668+
rot=0, grid=True, ax=None, figsize=None,
2669+
layout=None, **kwds):
26692670
"""
26702671
Make box plots from DataFrameGroupBy data.
26712672
@@ -2712,7 +2713,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
27122713
naxes = len(grouped)
27132714
nrows, ncols = _get_layout(naxes, layout=layout)
27142715
fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, squeeze=False,
2715-
sharex=False, sharey=True)
2716+
ax=ax, sharex=False, sharey=True, figsize=figsize)
27162717
axes = _flatten(axes)
27172718

27182719
ret = compat.OrderedDict()
@@ -2733,7 +2734,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
27332734
else:
27342735
df = frames[0]
27352736
ret = df.boxplot(column=column, fontsize=fontsize, rot=rot,
2736-
grid=grid, figsize=figsize, layout=layout, **kwds)
2737+
grid=grid, ax=ax, figsize=figsize, layout=layout, **kwds)
27372738
return ret
27382739

27392740

@@ -2779,17 +2780,10 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
27792780
by = [by]
27802781
columns = data._get_numeric_data().columns - by
27812782
naxes = len(columns)
2782-
2783-
if ax is None:
2784-
nrows, ncols = _get_layout(naxes, layout=layout)
2785-
fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes,
2786-
sharex=True, sharey=True,
2787-
figsize=figsize, ax=ax)
2788-
else:
2789-
if naxes > 1:
2790-
raise ValueError("Using an existing axis is not supported when plotting multiple columns.")
2791-
fig = ax.get_figure()
2792-
axes = ax.get_axes()
2783+
nrows, ncols = _get_layout(naxes, layout=layout)
2784+
fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes,
2785+
sharex=True, sharey=True,
2786+
figsize=figsize, ax=ax)
27932787

27942788
ravel_axes = _flatten(axes)
27952789

@@ -2974,12 +2968,6 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze=
29742968
if subplot_kw is None:
29752969
subplot_kw = {}
29762970

2977-
if ax is None:
2978-
fig = plt.figure(**fig_kw)
2979-
else:
2980-
fig = ax.get_figure()
2981-
fig.clear()
2982-
29832971
# Create empty object array to hold all axes. It's easiest to make it 1-d
29842972
# so we can just append subplots upon creation, and then
29852973
nplots = nrows * ncols
@@ -2989,6 +2977,21 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze=
29892977
elif nplots < naxes:
29902978
raise ValueError("naxes {0} is larger than layour size defined by nrows * ncols".format(naxes))
29912979

2980+
if ax is None:
2981+
fig = plt.figure(**fig_kw)
2982+
else:
2983+
fig = ax.get_figure()
2984+
# if ax is passed and a number of subplots is 1, return ax as it is
2985+
if naxes == 1:
2986+
if squeeze:
2987+
return fig, ax
2988+
else:
2989+
return fig, _flatten(ax)
2990+
else:
2991+
warnings.warn("To output multiple subplots, the figure containing the passed axes "
2992+
"is being cleared", UserWarning)
2993+
fig.clear()
2994+
29922995
axarr = np.empty(nplots, dtype=object)
29932996

29942997
def on_right(i):

0 commit comments

Comments
 (0)