Skip to content

Commit cdc1652

Browse files
committed
Handle boxplot arguments column and grid properly.
1 parent 7b2827b commit cdc1652

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

pandas/tools/plotting.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def boxplot(data, column=None, by=None, ax=None, fontsize=None,
2121
Parameters
2222
----------
2323
data : DataFrame
24-
column : column names or list of names, or vector
24+
column : column name or list of names, or vector
2525
Can be any valid input to groupby
2626
by : string or sequence
2727
Column in the DataFrame to group by
@@ -37,22 +37,34 @@ def plot_group(grouped, ax):
3737
ax.boxplot(values)
3838
ax.set_xticklabels(keys, rotation=rot, fontsize=fontsize)
3939

40+
if column == None:
41+
columns = None
42+
else:
43+
if isinstance(column, (list, tuple)):
44+
columns = column
45+
else:
46+
columns = [column]
47+
4048
if by is not None:
4149
if not isinstance(by, (list, tuple)):
4250
by = [by]
4351

44-
columns = None if column is None else [column]
4552
fig, axes = _grouped_plot_by_column(plot_group, data, columns=columns,
46-
by=by)
53+
by=by, grid=grid)
4754
ax = axes
4855
else:
4956
if ax is None:
5057
ax = plt.gca()
5158

5259
data = data._get_numeric_data()
53-
keys = [_stringify(x) for x in data.columns]
54-
ax.boxplot(list(data.values.T))
60+
if columns:
61+
cols = columns
62+
else:
63+
cols = data.columns
64+
keys = [_stringify(x) for x in cols]
65+
ax.boxplot(list(data[cols].values.T))
5566
ax.set_xticklabels(keys, rotation=rot, fontsize=fontsize)
67+
ax.grid(grid)
5668

5769
plt.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.1)
5870
return ax
@@ -108,7 +120,7 @@ def _grouped_plot(plotf, data, by=None, numeric_only=True):
108120
return fig, axes
109121

110122
def _grouped_plot_by_column(plotf, data, columns=None, by=None,
111-
numeric_only=True):
123+
numeric_only=True, grid=False):
112124
grouped = data.groupby(by)
113125
if columns is None:
114126
columns = data._get_numeric_data().columns - by
@@ -123,13 +135,17 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
123135
else:
124136
ravel_axes = []
125137
for row in axes:
126-
ravel_axes.extend(row)
138+
if isinstance(row, plt.Axes):
139+
ravel_axes.append(row)
140+
else:
141+
ravel_axes.extend(row)
127142

128143
for i, col in enumerate(columns):
129144
ax = ravel_axes[i]
130145
gp_col = grouped[col]
131146
plotf(gp_col, ax)
132147
ax.set_title(col)
148+
ax.grid(grid)
133149

134150
fig.suptitle('Boxplot grouped by %s' % by)
135151

0 commit comments

Comments
 (0)