Skip to content

Commit bb52860

Browse files
yarikopticwesm
authored andcommitted
ENH: pass figsize into _grouped_plot functions
atm figsize is not passed through by boxplot etc, making it (impossible?) to have custom figsize Conflicts: pandas/tools/plotting.py
1 parent 0d4dc67 commit bb52860

File tree

1 file changed

+25
-12
lines changed

1 file changed

+25
-12
lines changed

pandas/tools/plotting.py

+25-12
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ def hist(data, column, by=None, ax=None, fontsize=None):
1919
ax.set_xticklabels(keys, rotation=0, fontsize=fontsize)
2020
return ax
2121

22-
def grouped_hist(data, column, by=None, ax=None, bins=50, log=False):
22+
def grouped_hist(data, column, by=None, ax=None, bins=50, log=False,
23+
figsize=None):
2324
"""
2425
2526
Returns
@@ -29,14 +30,14 @@ def grouped_hist(data, column, by=None, ax=None, bins=50, log=False):
2930
def plot_group(group, ax):
3031
ax.hist(group[column].dropna(), bins=bins)
3132
fig = _grouped_plot(plot_group, data, by=by, sharex=False,
32-
sharey=False)
33+
sharey=False, figsize=figsize)
3334
fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9,
3435
hspace=0.3, wspace=0.2)
3536
return fig
3637

3738

3839
def boxplot(data, column=None, by=None, ax=None, fontsize=None,
39-
rot=0, grid=True):
40+
rot=0, grid=True, figsize=None):
4041
"""
4142
Make a box plot from DataFrame column optionally grouped by some columns or
4243
other inputs
@@ -73,7 +74,7 @@ def plot_group(grouped, ax):
7374
by = [by]
7475

7576
fig, axes = _grouped_plot_by_column(plot_group, data, columns=columns,
76-
by=by, grid=grid)
77+
by=by, grid=grid, figsize=figsize)
7778
ax = axes
7879
else:
7980
if ax is None:
@@ -98,7 +99,7 @@ def _stringify(x):
9899
else:
99100
return str(x)
100101

101-
def scatter_plot(data, x, y, by=None, ax=None):
102+
def scatter_plot(data, x, y, by=None, ax=None, figsize=None):
102103
"""
103104
104105
Returns
@@ -113,7 +114,7 @@ def plot_group(group, ax):
113114
ax.scatter(xvals, yvals)
114115

115116
if by is not None:
116-
fig = _grouped_plot(plot_group, data, by=by)
117+
fig = _grouped_plot(plot_group, data, by=by, figsize=figsize)
117118
else:
118119
fig = plt.figure()
119120
ax = fig.add_subplot(111)
@@ -123,14 +124,24 @@ def plot_group(group, ax):
123124

124125
return fig
125126

126-
def _grouped_plot(plotf, data, by=None, numeric_only=True, figsize=(10, 5),
127+
def _grouped_plot(plotf, data, by=None, numeric_only=True, figsize=None,
127128
sharex=True, sharey=True):
129+
import matplotlib.pyplot as plt
130+
131+
# allow to specify mpl default with 'default'
132+
if not (isinstance(figsize, str) and figsize == 'default'):
133+
figsize = (10, 5) # our default
134+
128135
grouped = data.groupby(by)
129136
ngroups = len(grouped)
130137

131138
nrows, ncols = _get_layout(ngroups)
132-
fig, axes = subplots(nrows=nrows, ncols=ncols, figsize=figsize,
133-
sharex=sharex, sharey=sharey)
139+
if figsize is None:
140+
# our favorite default beating matplotlib's idea of the
141+
# default size
142+
figsize = (10, 5)
143+
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize,
144+
sharex=sharex, sharey=sharey)
134145

135146
ravel_axes = []
136147
for row in axes:
@@ -146,7 +157,8 @@ def _grouped_plot(plotf, data, by=None, numeric_only=True, figsize=(10, 5),
146157
return fig, axes
147158

148159
def _grouped_plot_by_column(plotf, data, columns=None, by=None,
149-
numeric_only=True, grid=False):
160+
numeric_only=True, grid=False,
161+
figsize=None):
150162
import matplotlib.pyplot as plt
151163

152164
grouped = data.groupby(by)
@@ -155,8 +167,9 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
155167
ngroups = len(columns)
156168

157169
nrows, ncols = _get_layout(ngroups)
158-
fig, axes = subplots(nrows=nrows, ncols=ncols,
159-
sharex=True, sharey=True)
170+
fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
171+
sharex=True, sharey=True,
172+
figsize=figsize)
160173

161174
if isinstance(axes, plt.Axes):
162175
ravel_axes = [axes]

0 commit comments

Comments
 (0)