Skip to content

Commit 57205c4

Browse files
committedJul 6, 2014
Merge pull request #7391 from sinhrks/subplotax
BUG: inconsistent subplot ax handling
·
v3.0.0.dev0v0.14.1
2 parents 790d646 + bacebce commit 57205c4

File tree

3 files changed

+57
-39
lines changed

3 files changed

+57
-39
lines changed
 

‎doc/source/v0.14.1.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ Bug Fixes
173173
- Bug in ``to_timedelta`` that accepted invalid units and misinterpreted 'm/h' (:issue:`7611`, :issue: `6423`)
174174

175175
- Bug in grouped ``hist`` and ``scatter`` plots use old ``figsize`` default (:issue:`7394`)
176+
- Bug in plotting subplots with ``DataFrame.plot``, ``hist`` clears passed ``ax`` even if the number of subplots is one (:issue:`7391`).
177+
- Bug in plotting subplots with ``DataFrame.boxplot`` with ``by`` kw raises ``ValueError`` if the number of subplots exceeds 1 (:issue:`7391`).
176178

177179
- Bug in ``Panel.apply`` with a multi-index as an axis (:issue:`7469`)
178180

‎pandas/tests/test_graphics.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,13 @@ def test_plot(self):
883883
axes = _check_plot_works(df.plot, kind='bar', subplots=True)
884884
self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
885885

886+
# When ax is supplied and required number of axes is 1,
887+
# passed ax should be used:
888+
fig, ax = self.plt.subplots()
889+
axes = df.plot(kind='bar', subplots=True, ax=ax)
890+
self.assertEqual(len(axes), 1)
891+
self.assertIs(ax.get_axes(), axes[0])
892+
886893
def test_nonnumeric_exclude(self):
887894
df = DataFrame({'A': ["x", "y", "z"], 'B': [1, 2, 3]})
888895
ax = df.plot()
@@ -1443,17 +1450,23 @@ def test_boxplot(self):
14431450

14441451
df = DataFrame(np.random.rand(10, 2), columns=['Col1', 'Col2'])
14451452
df['X'] = Series(['A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B'])
1453+
df['Y'] = Series(['A'] * 10)
14461454
_check_plot_works(df.boxplot, by='X')
14471455

1448-
# When ax is supplied, existing axes should be used:
1456+
# When ax is supplied and required number of axes is 1,
1457+
# passed ax should be used:
14491458
fig, ax = self.plt.subplots()
14501459
axes = df.boxplot('Col1', by='X', ax=ax)
14511460
self.assertIs(ax.get_axes(), axes)
14521461

1453-
# Multiple columns with an ax argument is not supported
14541462
fig, ax = self.plt.subplots()
1455-
with tm.assertRaisesRegexp(ValueError, 'existing axis'):
1456-
df.boxplot(column=['Col1', 'Col2'], by='X', ax=ax)
1463+
axes = df.groupby('Y').boxplot(ax=ax, return_type='axes')
1464+
self.assertIs(ax.get_axes(), axes['A'])
1465+
1466+
# Multiple columns with an ax argument should use same figure
1467+
fig, ax = self.plt.subplots()
1468+
axes = df.boxplot(column=['Col1', 'Col2'], by='X', ax=ax, return_type='axes')
1469+
self.assertIs(axes['Col1'].get_figure(), fig)
14571470

14581471
# When by is None, check that all relevant lines are present in the dict
14591472
fig, ax = self.plt.subplots()
@@ -2204,32 +2217,32 @@ class TestDataFrameGroupByPlots(TestPlotBase):
22042217
@slow
22052218
def test_boxplot(self):
22062219
grouped = self.hist_df.groupby(by='gender')
2207-
box = _check_plot_works(grouped.boxplot, return_type='dict')
2208-
self._check_axes_shape(self.plt.gcf().axes, axes_num=2, layout=(1, 2))
2220+
axes = _check_plot_works(grouped.boxplot, return_type='axes')
2221+
self._check_axes_shape(axes.values(), axes_num=2, layout=(1, 2))
22092222

2210-
box = _check_plot_works(grouped.boxplot, subplots=False,
2211-
return_type='dict')
2212-
self._check_axes_shape(self.plt.gcf().axes, axes_num=2, layout=(1, 2))
2223+
axes = _check_plot_works(grouped.boxplot, subplots=False,
2224+
return_type='axes')
2225+
self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
22132226

22142227
tuples = lzip(string.ascii_letters[:10], range(10))
22152228
df = DataFrame(np.random.rand(10, 3),
22162229
index=MultiIndex.from_tuples(tuples))
22172230

22182231
grouped = df.groupby(level=1)
2219-
box = _check_plot_works(grouped.boxplot, return_type='dict')
2220-
self._check_axes_shape(self.plt.gcf().axes, axes_num=10, layout=(4, 3))
2232+
axes = _check_plot_works(grouped.boxplot, return_type='axes')
2233+
self._check_axes_shape(axes.values(), axes_num=10, layout=(4, 3))
22212234

2222-
box = _check_plot_works(grouped.boxplot, subplots=False,
2223-
return_type='dict')
2224-
self._check_axes_shape(self.plt.gcf().axes, axes_num=10, layout=(4, 3))
2235+
axes = _check_plot_works(grouped.boxplot, subplots=False,
2236+
return_type='axes')
2237+
self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
22252238

22262239
grouped = df.unstack(level=1).groupby(level=0, axis=1)
2227-
box = _check_plot_works(grouped.boxplot, return_type='dict')
2228-
self._check_axes_shape(self.plt.gcf().axes, axes_num=3, layout=(2, 2))
2240+
axes = _check_plot_works(grouped.boxplot, return_type='axes')
2241+
self._check_axes_shape(axes.values(), axes_num=3, layout=(2, 2))
22292242

2230-
box = _check_plot_works(grouped.boxplot, subplots=False,
2231-
return_type='dict')
2232-
self._check_axes_shape(self.plt.gcf().axes, axes_num=3, layout=(2, 2))
2243+
axes = _check_plot_works(grouped.boxplot, subplots=False,
2244+
return_type='axes')
2245+
self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
22332246

22342247
def test_series_plot_color_kwargs(self):
22352248
# GH1890

‎pandas/tools/plotting.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2667,7 +2667,8 @@ def plot_group(group, ax):
26672667

26682668

26692669
def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
2670-
rot=0, grid=True, figsize=None, layout=None, **kwds):
2670+
rot=0, grid=True, ax=None, figsize=None,
2671+
layout=None, **kwds):
26712672
"""
26722673
Make box plots from DataFrameGroupBy data.
26732674
@@ -2714,7 +2715,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
27142715
naxes = len(grouped)
27152716
nrows, ncols = _get_layout(naxes, layout=layout)
27162717
fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, squeeze=False,
2717-
sharex=False, sharey=True)
2718+
ax=ax, sharex=False, sharey=True, figsize=figsize)
27182719
axes = _flatten(axes)
27192720

27202721
ret = compat.OrderedDict()
@@ -2735,7 +2736,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
27352736
else:
27362737
df = frames[0]
27372738
ret = df.boxplot(column=column, fontsize=fontsize, rot=rot,
2738-
grid=grid, figsize=figsize, layout=layout, **kwds)
2739+
grid=grid, ax=ax, figsize=figsize, layout=layout, **kwds)
27392740
return ret
27402741

27412742

@@ -2781,17 +2782,10 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
27812782
by = [by]
27822783
columns = data._get_numeric_data().columns - by
27832784
naxes = len(columns)
2784-
2785-
if ax is None:
2786-
nrows, ncols = _get_layout(naxes, layout=layout)
2787-
fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes,
2788-
sharex=True, sharey=True,
2789-
figsize=figsize, ax=ax)
2790-
else:
2791-
if naxes > 1:
2792-
raise ValueError("Using an existing axis is not supported when plotting multiple columns.")
2793-
fig = ax.get_figure()
2794-
axes = ax.get_axes()
2785+
nrows, ncols = _get_layout(naxes, layout=layout)
2786+
fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes,
2787+
sharex=True, sharey=True,
2788+
figsize=figsize, ax=ax)
27952789

27962790
ravel_axes = _flatten(axes)
27972791

@@ -2976,12 +2970,6 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze=
29762970
if subplot_kw is None:
29772971
subplot_kw = {}
29782972

2979-
if ax is None:
2980-
fig = plt.figure(**fig_kw)
2981-
else:
2982-
fig = ax.get_figure()
2983-
fig.clear()
2984-
29852973
# Create empty object array to hold all axes. It's easiest to make it 1-d
29862974
# so we can just append subplots upon creation, and then
29872975
nplots = nrows * ncols
@@ -2991,6 +2979,21 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze=
29912979
elif nplots < naxes:
29922980
raise ValueError("naxes {0} is larger than layour size defined by nrows * ncols".format(naxes))
29932981

2982+
if ax is None:
2983+
fig = plt.figure(**fig_kw)
2984+
else:
2985+
fig = ax.get_figure()
2986+
# if ax is passed and a number of subplots is 1, return ax as it is
2987+
if naxes == 1:
2988+
if squeeze:
2989+
return fig, ax
2990+
else:
2991+
return fig, _flatten(ax)
2992+
else:
2993+
warnings.warn("To output multiple subplots, the figure containing the passed axes "
2994+
"is being cleared", UserWarning)
2995+
fig.clear()
2996+
29942997
axarr = np.empty(nplots, dtype=object)
29952998

29962999
def on_right(i):

0 commit comments

Comments
 (0)
Please sign in to comment.