@@ -21,7 +21,7 @@ def boxplot(data, column=None, by=None, ax=None, fontsize=None,
21
21
Parameters
22
22
----------
23
23
data : DataFrame
24
- column : column names or list of names, or vector
24
+ column : column name or list of names, or vector
25
25
Can be any valid input to groupby
26
26
by : string or sequence
27
27
Column in the DataFrame to group by
@@ -37,22 +37,34 @@ def plot_group(grouped, ax):
37
37
ax .boxplot (values )
38
38
ax .set_xticklabels (keys , rotation = rot , fontsize = fontsize )
39
39
40
+ if column == None :
41
+ columns = None
42
+ else :
43
+ if isinstance (column , (list , tuple )):
44
+ columns = column
45
+ else :
46
+ columns = [column ]
47
+
40
48
if by is not None :
41
49
if not isinstance (by , (list , tuple )):
42
50
by = [by ]
43
51
44
- columns = None if column is None else [column ]
45
52
fig , axes = _grouped_plot_by_column (plot_group , data , columns = columns ,
46
- by = by )
53
+ by = by , grid = grid )
47
54
ax = axes
48
55
else :
49
56
if ax is None :
50
57
ax = plt .gca ()
51
58
52
59
data = data ._get_numeric_data ()
53
- keys = [_stringify (x ) for x in data .columns ]
54
- ax .boxplot (list (data .values .T ))
60
+ if columns :
61
+ cols = columns
62
+ else :
63
+ cols = data .columns
64
+ keys = [_stringify (x ) for x in cols ]
65
+ ax .boxplot (list (data [cols ].values .T ))
55
66
ax .set_xticklabels (keys , rotation = rot , fontsize = fontsize )
67
+ ax .grid (grid )
56
68
57
69
plt .subplots_adjust (bottom = 0.15 , top = 0.9 , left = 0.1 , right = 0.9 , wspace = 0.1 )
58
70
return ax
@@ -108,7 +120,7 @@ def _grouped_plot(plotf, data, by=None, numeric_only=True):
108
120
return fig , axes
109
121
110
122
def _grouped_plot_by_column (plotf , data , columns = None , by = None ,
111
- numeric_only = True ):
123
+ numeric_only = True , grid = False ):
112
124
grouped = data .groupby (by )
113
125
if columns is None :
114
126
columns = data ._get_numeric_data ().columns - by
@@ -123,13 +135,17 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
123
135
else :
124
136
ravel_axes = []
125
137
for row in axes :
126
- ravel_axes .extend (row )
138
+ if isinstance (row , plt .Axes ):
139
+ ravel_axes .append (row )
140
+ else :
141
+ ravel_axes .extend (row )
127
142
128
143
for i , col in enumerate (columns ):
129
144
ax = ravel_axes [i ]
130
145
gp_col = grouped [col ]
131
146
plotf (gp_col , ax )
132
147
ax .set_title (col )
148
+ ax .grid (grid )
133
149
134
150
fig .suptitle ('Boxplot grouped by %s' % by )
135
151
0 commit comments