@@ -19,7 +19,8 @@ def hist(data, column, by=None, ax=None, fontsize=None):
19
19
ax .set_xticklabels (keys , rotation = 0 , fontsize = fontsize )
20
20
return ax
21
21
22
- def grouped_hist (data , column , by = None , ax = None , bins = 50 , log = False ):
22
+ def grouped_hist (data , column , by = None , ax = None , bins = 50 , log = False ,
23
+ figsize = None ):
23
24
"""
24
25
25
26
Returns
@@ -29,14 +30,14 @@ def grouped_hist(data, column, by=None, ax=None, bins=50, log=False):
29
30
def plot_group (group , ax ):
30
31
ax .hist (group [column ].dropna (), bins = bins )
31
32
fig = _grouped_plot (plot_group , data , by = by , sharex = False ,
32
- sharey = False )
33
+ sharey = False , figsize = figsize )
33
34
fig .subplots_adjust (bottom = 0.15 , top = 0.9 , left = 0.1 , right = 0.9 ,
34
35
hspace = 0.3 , wspace = 0.2 )
35
36
return fig
36
37
37
38
38
39
def boxplot (data , column = None , by = None , ax = None , fontsize = None ,
39
- rot = 0 , grid = True ):
40
+ rot = 0 , grid = True , figsize = None ):
40
41
"""
41
42
Make a box plot from DataFrame column optionally grouped by some columns or
42
43
other inputs
@@ -73,7 +74,7 @@ def plot_group(grouped, ax):
73
74
by = [by ]
74
75
75
76
fig , axes = _grouped_plot_by_column (plot_group , data , columns = columns ,
76
- by = by , grid = grid )
77
+ by = by , grid = grid , figsize = figsize )
77
78
ax = axes
78
79
else :
79
80
if ax is None :
@@ -98,7 +99,7 @@ def _stringify(x):
98
99
else :
99
100
return str (x )
100
101
101
- def scatter_plot (data , x , y , by = None , ax = None ):
102
+ def scatter_plot (data , x , y , by = None , ax = None , figsize = None ):
102
103
"""
103
104
104
105
Returns
@@ -113,7 +114,7 @@ def plot_group(group, ax):
113
114
ax .scatter (xvals , yvals )
114
115
115
116
if by is not None :
116
- fig = _grouped_plot (plot_group , data , by = by )
117
+ fig = _grouped_plot (plot_group , data , by = by , figsize = figsize )
117
118
else :
118
119
fig = plt .figure ()
119
120
ax = fig .add_subplot (111 )
@@ -123,14 +124,24 @@ def plot_group(group, ax):
123
124
124
125
return fig
125
126
126
- def _grouped_plot (plotf , data , by = None , numeric_only = True , figsize = ( 10 , 5 ) ,
127
+ def _grouped_plot (plotf , data , by = None , numeric_only = True , figsize = None ,
127
128
sharex = True , sharey = True ):
129
+ import matplotlib .pyplot as plt
130
+
131
+ # allow to specify mpl default with 'default'
132
+ if not (isinstance (figsize , str ) and figsize == 'default' ):
133
+ figsize = (10 , 5 ) # our default
134
+
128
135
grouped = data .groupby (by )
129
136
ngroups = len (grouped )
130
137
131
138
nrows , ncols = _get_layout (ngroups )
132
- fig , axes = subplots (nrows = nrows , ncols = ncols , figsize = figsize ,
133
- sharex = sharex , sharey = sharey )
139
+ if figsize is None :
140
+ # our favorite default beating matplotlib's idea of the
141
+ # default size
142
+ figsize = (10 , 5 )
143
+ fig , axes = plt .subplots (nrows = nrows , ncols = ncols , figsize = figsize ,
144
+ sharex = sharex , sharey = sharey )
134
145
135
146
ravel_axes = []
136
147
for row in axes :
@@ -146,7 +157,8 @@ def _grouped_plot(plotf, data, by=None, numeric_only=True, figsize=(10, 5),
146
157
return fig , axes
147
158
148
159
def _grouped_plot_by_column (plotf , data , columns = None , by = None ,
149
- numeric_only = True , grid = False ):
160
+ numeric_only = True , grid = False ,
161
+ figsize = None ):
150
162
import matplotlib .pyplot as plt
151
163
152
164
grouped = data .groupby (by )
@@ -155,8 +167,9 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
155
167
ngroups = len (columns )
156
168
157
169
nrows , ncols = _get_layout (ngroups )
158
- fig , axes = subplots (nrows = nrows , ncols = ncols ,
159
- sharex = True , sharey = True )
170
+ fig , axes = plt .subplots (nrows = nrows , ncols = ncols ,
171
+ sharex = True , sharey = True ,
172
+ figsize = figsize )
160
173
161
174
if isinstance (axes , plt .Axes ):
162
175
ravel_axes = [axes ]
0 commit comments