Skip to content

Commit 7a2db77

Browse files
committed
Merge pull request #7351 from sinhrks/boxcln
CLN: Simplify boxplot and tests
2 parents f1c5386 + de69d62 commit 7a2db77

File tree

2 files changed

+109
-139
lines changed

2 files changed

+109
-139
lines changed

pandas/tests/test_graphics.py

+70-75
Original file line number | Diff line number | Diff line change
@@ -356,6 +356,54 @@ def _check_has_errorbars(self, axes, xerr=0, yerr=0):
356356
self.assertEqual(xerr, xerr_count)
357357
self.assertEqual(yerr, yerr_count)
358358

359+
def _check_box_return_type(self, returned, return_type, expected_keys=None):
360+
"""
361+
Check box returned type is correct
362+
363+
Parameters
364+
----------
365+
returned : object to be tested, returned from boxplot
366+
return_type : str
367+
return_type passed to boxplot
368+
expected_keys : list-like, optional
369+
group labels in subplot case. If not passed,
370+
the function checks assuming boxplot uses single ax
371+
"""
372+
from matplotlib.axes import Axes
373+
types = {'dict': dict, 'axes': Axes, 'both': tuple}
374+
if expected_keys is None:
375+
# should be fixed when the returning default is changed
376+
if return_type is None:
377+
return_type = 'dict'
378+
379+
self.assertTrue(isinstance(returned, types[return_type]))
380+
if return_type == 'both':
381+
self.assertIsInstance(returned.ax, Axes)
382+
self.assertIsInstance(returned.lines, dict)
383+
else:
384+
# should be fixed when the returning default is changed
385+
if return_type is None:
386+
for r in self._flatten_visible(returned):
387+
self.assertIsInstance(r, Axes)
388+
return
389+
390+
self.assertTrue(isinstance(returned, OrderedDict))
391+
self.assertEqual(sorted(returned.keys()), sorted(expected_keys))
392+
for key, value in iteritems(returned):
393+
self.assertTrue(isinstance(value, types[return_type]))
394+
# check returned dict has correct mapping
395+
if return_type == 'axes':
396+
self.assertEqual(value.get_title(), key)
397+
elif return_type == 'both':
398+
self.assertEqual(value.ax.get_title(), key)
399+
self.assertIsInstance(value.ax, Axes)
400+
self.assertIsInstance(value.lines, dict)
401+
elif return_type == 'dict':
402+
line = value['medians'][0]
403+
self.assertEqual(line.get_axes().get_title(), key)
404+
else:
405+
raise AssertionError
406+
359407

360408
@tm.mplskip
361409
class TestSeriesPlots(TestPlotBase):
@@ -1421,65 +1469,20 @@ def test_boxplot_return_type(self):
14211469

14221470
with tm.assert_produces_warning(FutureWarning):
14231471
result = df.boxplot()
1424-
self.assertIsInstance(result, dict) # change to Axes in future
1472+
# change to Axes in future
1473+
self._check_box_return_type(result, 'dict')
14251474

14261475
with tm.assert_produces_warning(False):
14271476
result = df.boxplot(return_type='dict')
1428-
self.assertIsInstance(result, dict)
1477+
self._check_box_return_type(result, 'dict')
14291478

14301479
with tm.assert_produces_warning(False):
14311480
result = df.boxplot(return_type='axes')
1432-
self.assertIsInstance(result, mpl.axes.Axes)
1481+
self._check_box_return_type(result, 'axes')
14331482

14341483
with tm.assert_produces_warning(False):
14351484
result = df.boxplot(return_type='both')
1436-
self.assertIsInstance(result, tuple)
1437-
1438-
@slow
1439-
def test_boxplot_return_type_by(self):
1440-
import matplotlib as mpl
1441-
1442-
df = DataFrame(np.random.randn(10, 2))
1443-
df['g'] = ['a'] * 5 + ['b'] * 5
1444-
1445-
# old style: return_type=None
1446-
result = df.boxplot(by='g')
1447-
self.assertIsInstance(result, np.ndarray)
1448-
self.assertIsInstance(result[0], mpl.axes.Axes)
1449-
1450-
result = df.boxplot(by='g', return_type='dict')
1451-
self.assertIsInstance(result, dict)
1452-
self.assertIsInstance(result[0], dict)
1453-
1454-
result = df.boxplot(by='g', return_type='axes')
1455-
self.assertIsInstance(result, dict)
1456-
self.assertIsInstance(result[0], mpl.axes.Axes)
1457-
1458-
result = df.boxplot(by='g', return_type='both')
1459-
self.assertIsInstance(result, dict)
1460-
self.assertIsInstance(result[0], tuple)
1461-
self.assertIsInstance(result[0][0], mpl.axes.Axes)
1462-
self.assertIsInstance(result[0][1], dict)
1463-
1464-
# now for groupby
1465-
with tm.assert_produces_warning(FutureWarning):
1466-
result = df.groupby('g').boxplot()
1467-
self.assertIsInstance(result, dict)
1468-
self.assertIsInstance(result['a'], dict)
1469-
1470-
result = df.groupby('g').boxplot(return_type='dict')
1471-
self.assertIsInstance(result, dict)
1472-
self.assertIsInstance(result['a'], dict)
1473-
1474-
result = df.groupby('g').boxplot(return_type='axes')
1475-
self.assertIsInstance(result, dict)
1476-
self.assertIsInstance(result['a'], mpl.axes.Axes)
1477-
1478-
result = df.groupby('g').boxplot(return_type='both')
1479-
self.assertIsInstance(result, dict)
1480-
self.assertIsInstance(result['a'], tuple)
1481-
self.assertIsInstance(result['a'][0], mpl.axes.Axes)
1482-
self.assertIsInstance(result['a'][1], dict)
1485+
self._check_box_return_type(result, 'both')
14831486

14841487
@slow
14851488
def test_kde(self):
@@ -2278,47 +2281,39 @@ def test_grouped_hist(self):
22782281
with tm.assertRaises(AttributeError):
22792282
plotting.grouped_hist(df.A, by=df.C, foo='bar')
22802283

2281-
def _check_box_dict(self, returned, return_type,
2282-
expected_klass, expected_keys):
2283-
self.assertTrue(isinstance(returned, OrderedDict))
2284-
self.assertEqual(sorted(returned.keys()), sorted(expected_keys))
2285-
for key, value in iteritems(returned):
2286-
self.assertTrue(isinstance(value, expected_klass))
2287-
# check returned dict has correct mapping
2288-
if return_type == 'axes':
2289-
self.assertEqual(value.get_title(), key)
2290-
elif return_type == 'both':
2291-
self.assertEqual(value.ax.get_title(), key)
2292-
elif return_type == 'dict':
2293-
line = value['medians'][0]
2294-
self.assertEqual(line.get_axes().get_title(), key)
2295-
else:
2296-
raise AssertionError
2297-
22982284
@slow
22992285
def test_grouped_box_return_type(self):
2300-
import matplotlib.axes
2301-
23022286
df = self.hist_df
23032287

2288+
# old style: return_type=None
2289+
result = df.boxplot(by='gender')
2290+
self.assertIsInstance(result, np.ndarray)
2291+
self._check_box_return_type(result, None,
2292+
expected_keys=['height', 'weight', 'category'])
2293+
2294+
# now for groupby
2295+
with tm.assert_produces_warning(FutureWarning):
2296+
result = df.groupby('gender').boxplot()
2297+
self._check_box_return_type(result, 'dict', expected_keys=['Male', 'Female'])
2298+
23042299
columns2 = 'X B C D A G Y N Q O'.split()
23052300
df2 = DataFrame(random.randn(50, 10), columns=columns2)
23062301
categories2 = 'A B C D E F G H I J'.split()
23072302
df2['category'] = categories2 * 5
23082303

2309-
types = {'dict': dict, 'axes': matplotlib.axes.Axes, 'both': tuple}
2310-
for t, klass in iteritems(types):
2304+
for t in ['dict', 'axes', 'both']:
23112305
returned = df.groupby('classroom').boxplot(return_type=t)
2312-
self._check_box_dict(returned, t, klass, ['A', 'B', 'C'])
2306+
self._check_box_return_type(returned, t, expected_keys=['A', 'B', 'C'])
23132307

23142308
returned = df.boxplot(by='classroom', return_type=t)
2315-
self._check_box_dict(returned, t, klass, ['height', 'weight', 'category'])
2309+
self._check_box_return_type(returned, t,
2310+
expected_keys=['height', 'weight', 'category'])
23162311

23172312
returned = df2.groupby('category').boxplot(return_type=t)
2318-
self._check_box_dict(returned, t, klass, categories2)
2313+
self._check_box_return_type(returned, t, expected_keys=categories2)
23192314

23202315
returned = df2.boxplot(by='category', return_type=t)
2321-
self._check_box_dict(returned, t, klass, columns2)
2316+
self._check_box_return_type(returned, t, expected_keys=columns2)
23222317

23232318
@slow
23242319
def test_grouped_box_layout(self):

pandas/tools/plotting.py

+39-64
Original file line number | Diff line number | Diff line change
@@ -2323,13 +2323,11 @@ def boxplot(data, column=None, by=None, ax=None, fontsize=None,
23232323
if return_type not in valid_types:
23242324
raise ValueError("return_type")
23252325

2326-
23272326
from pandas import Series, DataFrame
23282327
if isinstance(data, Series):
23292328
data = DataFrame({'x': data})
23302329
column = 'x'
23312330

2332-
23332331
def _get_colors():
23342332
return _get_standard_colors(color=kwds.get('color'), num_colors=1)
23352333

@@ -2340,8 +2338,9 @@ def maybe_color_bp(bp):
23402338
setp(bp['whiskers'],color=colors[0],alpha=1)
23412339
setp(bp['medians'],color=colors[2],alpha=1)
23422340

2343-
def plot_group(grouped, ax):
2344-
keys, values = zip(*grouped)
2341+
BP = namedtuple("Boxplot", ['ax', 'lines']) # namedtuple to hold results
2342+
2343+
def plot_group(keys, values, ax):
23452344
keys = [com.pprint_thing(x) for x in keys]
23462345
values = [remove_na(v) for v in values]
23472346
bp = ax.boxplot(values, **kwds)
@@ -2350,7 +2349,14 @@ def plot_group(grouped, ax):
23502349
else:
23512350
ax.set_yticklabels(keys, rotation=rot, fontsize=fontsize)
23522351
maybe_color_bp(bp)
2353-
return bp
2352+
2353+
# Return axes in multiplot case, maybe revisit later # 985
2354+
if return_type == 'dict':
2355+
return bp
2356+
elif return_type == 'both':
2357+
return BP(ax=ax, lines=bp)
2358+
else:
2359+
return ax
23542360

23552361
colors = _get_colors()
23562362
if column is None:
@@ -2361,56 +2367,14 @@ def plot_group(grouped, ax):
23612367
else:
23622368
columns = [column]
23632369

2364-
BP = namedtuple("Boxplot", ['ax', 'lines']) # namedtuple to hold results
2365-
23662370
if by is not None:
2367-
fig, axes, d = _grouped_plot_by_column(plot_group, data, columns=columns,
2368-
by=by, grid=grid, figsize=figsize,
2369-
ax=ax, layout=layout)
2370-
2371-
# Return axes in multiplot case, maybe revisit later # 985
2372-
if return_type is None:
2373-
ret = axes
2374-
if return_type == 'axes':
2375-
ret = compat.OrderedDict()
2376-
axes = _flatten(axes)[:len(d)]
2377-
for k, ax in zip(d.keys(), axes):
2378-
ret[k] = ax
2379-
elif return_type == 'dict':
2380-
ret = d
2381-
elif return_type == 'both':
2382-
ret = compat.OrderedDict()
2383-
axes = _flatten(axes)[:len(d)]
2384-
for (k, line), ax in zip(d.items(), axes):
2385-
ret[k] = BP(ax=ax, lines=line)
2371+
result = _grouped_plot_by_column(plot_group, data, columns=columns,
2372+
by=by, grid=grid, figsize=figsize,
2373+
ax=ax, layout=layout, return_type=return_type)
23862374
else:
23872375
if layout is not None:
23882376
raise ValueError("The 'layout' keyword is not supported when "
23892377
"'by' is None")
2390-
if ax is None:
2391-
ax = _gca()
2392-
fig = ax.get_figure()
2393-
data = data._get_numeric_data()
2394-
if columns:
2395-
cols = columns
2396-
else:
2397-
cols = data.columns
2398-
keys = [com.pprint_thing(x) for x in cols]
2399-
2400-
# Return boxplot dict in single plot case
2401-
2402-
clean_values = [remove_na(x) for x in data[cols].values.T]
2403-
2404-
bp = ax.boxplot(clean_values, **kwds)
2405-
maybe_color_bp(bp)
2406-
2407-
if kwds.get('vert', 1):
2408-
ax.set_xticklabels(keys, rotation=rot, fontsize=fontsize)
2409-
else:
2410-
ax.set_yticklabels(keys, rotation=rot, fontsize=fontsize)
2411-
ax.grid(grid)
2412-
2413-
ret = ax
24142378

24152379
if return_type is None:
24162380
msg = ("\nThe default value for 'return_type' will change to "
@@ -2420,13 +2384,18 @@ def plot_group(grouped, ax):
24202384
"return_type='dict'.")
24212385
warnings.warn(msg, FutureWarning)
24222386
return_type = 'dict'
2423-
if return_type == 'dict':
2424-
ret = bp
2425-
elif return_type == 'both':
2426-
ret = BP(ax=ret, lines=bp)
2387+
if ax is None:
2388+
ax = _gca()
2389+
data = data._get_numeric_data()
2390+
if columns is None:
2391+
columns = data.columns
2392+
else:
2393+
data = data[columns]
24272394

2428-
fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
2429-
return ret
2395+
result = plot_group(columns, data.values.T, ax)
2396+
ax.grid(grid)
2397+
2398+
return result
24302399

24312400

24322401
def format_date_labels(ax, rot):
@@ -2734,7 +2703,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
27342703
if subplots is True:
27352704
naxes = len(grouped)
27362705
nrows, ncols = _get_layout(naxes, layout=layout)
2737-
_, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, squeeze=False,
2706+
fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, squeeze=False,
27382707
sharex=False, sharey=True)
27392708
axes = _flatten(axes)
27402709

@@ -2744,6 +2713,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
27442713
rot=rot, grid=grid, **kwds)
27452714
ax.set_title(com.pprint_thing(key))
27462715
ret[key] = d
2716+
fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
27472717
else:
27482718
from pandas.tools.merge import concat
27492719
keys, frames = zip(*grouped)
@@ -2795,9 +2765,8 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
27952765

27962766
def _grouped_plot_by_column(plotf, data, columns=None, by=None,
27972767
numeric_only=True, grid=False,
2798-
figsize=None, ax=None, layout=None, **kwargs):
2799-
from pandas.core.frame import DataFrame
2800-
2768+
figsize=None, ax=None, layout=None, return_type=None,
2769+
**kwargs):
28012770
grouped = data.groupby(by)
28022771
if columns is None:
28032772
if not isinstance(by, (list, tuple)):
@@ -2818,20 +2787,26 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
28182787

28192788
ravel_axes = _flatten(axes)
28202789

2821-
out_dict = compat.OrderedDict()
2790+
result = compat.OrderedDict()
28222791
for i, col in enumerate(columns):
28232792
ax = ravel_axes[i]
28242793
gp_col = grouped[col]
2825-
re_plotf = plotf(gp_col, ax, **kwargs)
2794+
keys, values = zip(*gp_col)
2795+
re_plotf = plotf(keys, values, ax, **kwargs)
28262796
ax.set_title(col)
28272797
ax.set_xlabel(com.pprint_thing(by))
2798+
result[col] = re_plotf
28282799
ax.grid(grid)
2829-
out_dict[col] = re_plotf
2800+
2801+
# Return axes in multiplot case, maybe revisit later # 985
2802+
if return_type is None:
2803+
result = axes
28302804

28312805
byline = by[0] if len(by) == 1 else by
28322806
fig.suptitle('Boxplot grouped by %s' % byline)
2807+
fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
28332808

2834-
return fig, axes, out_dict
2809+
return result
28352810

28362811

28372812
def table(ax, data, rowLabels=None, colLabels=None,

0 commit comments

Comments
 (0)