@@ -12,7 +12,8 @@ def hist(data, column, by=None, ax=None, fontsize=None):
12
12
ax .set_xticklabels (keys , rotation = 0 , fontsize = fontsize )
13
13
return ax
14
14
15
- def boxplot (data , column , by = None , ax = None , fontsize = None , rot = 0 ):
15
+ def boxplot (data , column = None , by = None , ax = None , fontsize = None ,
16
+ rot = 0 , grid = True ):
16
17
"""
17
18
Make a box plot from DataFrame column optionally grouped by some columns or
18
19
other inputs
@@ -30,19 +31,38 @@ def boxplot(data, column, by=None, ax=None, fontsize=None, rot=0):
30
31
-------
31
32
ax : matplotlib.axes.AxesSubplot
32
33
"""
33
- keys , values = zip (* data .groupby (by )[column ])
34
+ def plot_group (grouped , ax ):
35
+ keys , values = zip (* grouped )
36
+ keys = [_stringify (x ) for x in keys ]
37
+ ax .boxplot (values )
38
+ ax .set_xticklabels (keys , rotation = rot , fontsize = fontsize )
34
39
35
- if ax is None :
36
- ax = plt .gca ()
37
- ax .boxplot (values )
38
- ax .set_xticklabels (keys , rotation = rot , fontsize = fontsize )
40
+ if by is not None :
41
+ if not isinstance (by , (list , tuple )):
42
+ by = [by ]
39
43
40
- ax .set_xlabel (str (by ))
41
- ax .set_ylabel (str (column ))
44
+ columns = None if column is None else [column ]
45
+ fig , axes = _grouped_plot_by_column (plot_group , data , columns = columns ,
46
+ by = by )
47
+ ax = axes
48
+ else :
49
+ if ax is None :
50
+ ax = plt .gca ()
42
51
43
- plt .subplots_adjust (bottom = 0.15 )
52
+ data = data ._get_numeric_data ()
53
+ keys = [_stringify (x ) for x in data .columns ]
54
+ ax .boxplot (list (data .values .T ))
55
+ ax .set_xticklabels (keys , rotation = rot , fontsize = fontsize )
56
+
57
+ plt .subplots_adjust (bottom = 0.15 , top = 0.9 , left = 0.1 , right = 0.9 , wspace = 0.1 )
44
58
return ax
45
59
60
+ def _stringify (x ):
61
+ if isinstance (x , tuple ):
62
+ return '|' .join (str (y ) for y in x )
63
+ else :
64
+ return str (x )
65
+
46
66
def scatter_plot (data , x , y , by = None , ax = None ):
47
67
"""
48
68
@@ -66,7 +86,7 @@ def plot_group(group, ax):
66
86
67
87
return fig
68
88
69
- def _grouped_plot (plotf , data , by = None ):
89
+ def _grouped_plot (plotf , data , by = None , numeric_only = True ):
70
90
grouped = data .groupby (by )
71
91
ngroups = len (grouped )
72
92
@@ -80,10 +100,40 @@ def _grouped_plot(plotf, data, by=None):
80
100
81
101
for i , (key , group ) in enumerate (grouped ):
82
102
ax = ravel_axes [i ]
103
+ if numeric_only :
104
+ group = group ._get_numeric_data ()
83
105
plotf (group , ax )
84
106
ax .set_title (str (key ))
85
107
86
- return fig
108
+ return fig , axes
109
+
110
+ def _grouped_plot_by_column (plotf , data , columns = None , by = None ,
111
+ numeric_only = True ):
112
+ grouped = data .groupby (by )
113
+ if columns is None :
114
+ columns = data .columns - by
115
+ ngroups = len (columns )
116
+
117
+ nrows , ncols = _get_layout (ngroups )
118
+ fig , axes = plt .subplots (nrows = nrows , ncols = ncols ,
119
+ sharex = True , sharey = True )
120
+
121
+ if isinstance (axes , plt .Axes ):
122
+ ravel_axes = [axes ]
123
+ else :
124
+ ravel_axes = []
125
+ for row in axes :
126
+ ravel_axes .extend (row )
127
+
128
+ for i , col in enumerate (columns ):
129
+ ax = ravel_axes [i ]
130
+ gp_col = grouped [col ]
131
+ plotf (gp_col , ax )
132
+ ax .set_title (col )
133
+
134
+ fig .suptitle ('Boxplot grouped by %s' % by )
135
+
136
+ return fig , axes
87
137
88
138
def _get_layout (nplots ):
89
139
if nplots == 1 :
0 commit comments