Skip to content

Commit 51cc9d9

Browse files
committed
Merge pull request #3842 from cpcloud/hist-figsize-3834
ENH: add figsize argument to DataFrame and Series hist methods
2 parents c2e12b4 + 94f1d22 commit 51cc9d9

File tree

4 files changed

+67
-39
lines changed

4 files changed

+67
-39
lines changed

RELEASE.rst

+3
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ pandas 0.11.1
7979
spurious plots from showing up.
8080
- Added FAQ section on repr display options, to help users customize their setup.
8181
- ``where`` operations that result in block splitting are much faster (GH3733_)
82+
- Series and DataFrame hist methods now take a ``figsize`` argument (GH3834_)
8283

8384
**API Changes**
8485

@@ -312,6 +313,8 @@ pandas 0.11.1
312313
.. _GH3726: https://github.com/pydata/pandas/issues/3726
313314
.. _GH3795: https://github.com/pydata/pandas/issues/3795
314315
.. _GH3814: https://github.com/pydata/pandas/issues/3814
316+
.. _GH3834: https://github.com/pydata/pandas/issues/3834
317+
315318

316319
pandas 0.11.0
317320
=============

doc/source/v0.11.1.txt

+3
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ Enhancements
288288

289289
dff.groupby('B').filter(lambda x: len(x) > 2, dropna=False)
290290

291+
- Series and DataFrame hist methods now take a ``figsize`` argument (GH3834_)
292+
291293

292294
Bug Fixes
293295
~~~~~~~~~
@@ -396,3 +398,4 @@ on GitHub for a complete list.
396398
.. _GH3741: https://github.com/pydata/pandas/issues/3741
397399
.. _GH3726: https://github.com/pydata/pandas/issues/3726
398400
.. _GH3425: https://github.com/pydata/pandas/issues/3425
401+
.. _GH3834: https://github.com/pydata/pandas/issues/3834

pandas/tests/test_graphics.py

+35-23
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pandas import Series, DataFrame, MultiIndex, PeriodIndex, date_range
99
import pandas.util.testing as tm
1010
from pandas.util.testing import ensure_clean
11-
from pandas.core.config import set_option,get_option,config_prefix
11+
from pandas.core.config import set_option
1212

1313
import numpy as np
1414

@@ -28,11 +28,6 @@ class TestSeriesPlots(unittest.TestCase):
2828

2929
@classmethod
3030
def setUpClass(cls):
31-
import sys
32-
33-
# if 'IPython' in sys.modules:
34-
# raise nose.SkipTest
35-
3631
try:
3732
import matplotlib as mpl
3833
mpl.use('Agg', warn=False)
@@ -150,9 +145,16 @@ def test_irregular_datetime(self):
150145
def test_hist(self):
151146
_check_plot_works(self.ts.hist)
152147
_check_plot_works(self.ts.hist, grid=False)
153-
148+
_check_plot_works(self.ts.hist, figsize=(8, 10))
154149
_check_plot_works(self.ts.hist, by=self.ts.index.month)
155150

151+
def test_plot_fails_when_ax_differs_from_figure(self):
152+
from pylab import figure
153+
fig1 = figure()
154+
fig2 = figure()
155+
ax1 = fig1.add_subplot(111)
156+
self.assertRaises(AssertionError, self.ts.hist, ax=ax1, figure=fig2)
157+
156158
@slow
157159
def test_kde(self):
158160
_skip_if_no_scipy()
@@ -258,7 +260,8 @@ def test_plot(self):
258260
(u'\u03b4', 6),
259261
(u'\u03b4', 7)], names=['i0', 'i1'])
260262
columns = MultiIndex.from_tuples([('bar', u'\u0394'),
261-
('bar', u'\u0395')], names=['c0', 'c1'])
263+
('bar', u'\u0395')], names=['c0',
264+
'c1'])
262265
df = DataFrame(np.random.randint(0, 10, (8, 2)),
263266
columns=columns,
264267
index=index)
@@ -269,9 +272,9 @@ def test_nonnumeric_exclude(self):
269272
import matplotlib.pyplot as plt
270273
plt.close('all')
271274

272-
df = DataFrame({'A': ["x", "y", "z"], 'B': [1,2,3]})
275+
df = DataFrame({'A': ["x", "y", "z"], 'B': [1, 2, 3]})
273276
ax = df.plot()
274-
self.assert_(len(ax.get_lines()) == 1) #B was plotted
277+
self.assert_(len(ax.get_lines()) == 1) # B was plotted
275278

276279
@slow
277280
def test_label(self):
@@ -434,21 +437,24 @@ def test_bar_center(self):
434437
ax = df.plot(kind='bar', grid=True)
435438
self.assertEqual(ax.xaxis.get_ticklocs()[0],
436439
ax.patches[0].get_x() + ax.patches[0].get_width())
440+
437441
@slow
438442
def test_bar_log(self):
439443
# GH3254, GH3298 matplotlib/matplotlib#1882, #1892
440444
# regressions in 1.2.1
441445

442-
df = DataFrame({'A': [3] * 5, 'B': range(1,6)}, index=range(5))
443-
ax = df.plot(kind='bar', grid=True,log=True)
444-
self.assertEqual(ax.yaxis.get_ticklocs()[0],1.0)
446+
df = DataFrame({'A': [3] * 5, 'B': range(1, 6)}, index=range(5))
447+
ax = df.plot(kind='bar', grid=True, log=True)
448+
self.assertEqual(ax.yaxis.get_ticklocs()[0], 1.0)
445449

446-
p1 = Series([200,500]).plot(log=True,kind='bar')
447-
p2 = DataFrame([Series([200,300]),Series([300,500])]).plot(log=True,kind='bar',subplots=True)
450+
p1 = Series([200, 500]).plot(log=True, kind='bar')
451+
p2 = DataFrame([Series([200, 300]),
452+
Series([300, 500])]).plot(log=True, kind='bar',
453+
subplots=True)
448454

449-
(p1.yaxis.get_ticklocs() == np.array([ 0.625, 1.625]))
450-
(p2[0].yaxis.get_ticklocs() == np.array([ 1., 10., 100., 1000.])).all()
451-
(p2[1].yaxis.get_ticklocs() == np.array([ 1., 10., 100., 1000.])).all()
455+
(p1.yaxis.get_ticklocs() == np.array([0.625, 1.625]))
456+
(p2[0].yaxis.get_ticklocs() == np.array([1., 10., 100., 1000.])).all()
457+
(p2[1].yaxis.get_ticklocs() == np.array([1., 10., 100., 1000.])).all()
452458

453459
@slow
454460
def test_boxplot(self):
@@ -508,6 +514,9 @@ def test_hist(self):
508514
# make sure sharex, sharey is handled
509515
_check_plot_works(df.hist, sharex=True, sharey=True)
510516

517+
# handle figsize arg
518+
_check_plot_works(df.hist, figsize=(8, 10))
519+
511520
# make sure xlabelsize and xrot are handled
512521
ser = df[0]
513522
xf, yf = 20, 20
@@ -727,6 +736,7 @@ def test_invalid_kind(self):
727736
df = DataFrame(np.random.randn(10, 2))
728737
self.assertRaises(ValueError, df.plot, kind='aasdf')
729738

739+
730740
class TestDataFrameGroupByPlots(unittest.TestCase):
731741

732742
@classmethod
@@ -786,10 +796,10 @@ def test_time_series_plot_color_with_empty_kwargs(self):
786796

787797
plt.close('all')
788798
for i in range(3):
789-
ax = Series(np.arange(12) + 1, index=date_range(
790-
'1/1/2000', periods=12)).plot()
799+
ax = Series(np.arange(12) + 1, index=date_range('1/1/2000',
800+
periods=12)).plot()
791801

792-
line_colors = [ l.get_color() for l in ax.get_lines() ]
802+
line_colors = [l.get_color() for l in ax.get_lines()]
793803
self.assert_(line_colors == ['b', 'g', 'r'])
794804

795805
@slow
@@ -829,7 +839,6 @@ def test_grouped_hist(self):
829839
self.assertRaises(AttributeError, plotting.grouped_hist, df.A,
830840
by=df.C, foo='bar')
831841

832-
833842
def test_option_mpl_style(self):
834843
# just a sanity check
835844
try:
@@ -845,14 +854,15 @@ def test_option_mpl_style(self):
845854
except ValueError:
846855
pass
847856

857+
848858
def _check_plot_works(f, *args, **kwargs):
849859
import matplotlib.pyplot as plt
850860

851861
fig = plt.gcf()
852862
plt.clf()
853863
ax = fig.add_subplot(211)
854864
ret = f(*args, **kwargs)
855-
assert(ret is not None) # do something more intelligent
865+
assert ret is not None # do something more intelligent
856866

857867
ax = fig.add_subplot(212)
858868
try:
@@ -865,10 +875,12 @@ def _check_plot_works(f, *args, **kwargs):
865875
with ensure_clean() as path:
866876
plt.savefig(path)
867877

878+
868879
def curpath():
869880
pth, _ = os.path.split(os.path.abspath(__file__))
870881
return pth
871882

883+
872884
if __name__ == '__main__':
873885
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],
874886
exit=False)

pandas/tools/plotting.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -658,9 +658,9 @@ def r(h):
658658
return ax
659659

660660

661-
def grouped_hist(data, column=None, by=None, ax=None, bins=50,
662-
figsize=None, layout=None, sharex=False, sharey=False,
663-
rot=90, grid=True, **kwargs):
661+
def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None,
662+
layout=None, sharex=False, sharey=False, rot=90, grid=True,
663+
**kwargs):
664664
"""
665665
Grouped histogram
666666
@@ -1839,10 +1839,9 @@ def plot_group(group, ax):
18391839
return fig
18401840

18411841

1842-
def hist_frame(
1843-
data, column=None, by=None, grid=True, xlabelsize=None, xrot=None,
1844-
ylabelsize=None, yrot=None, ax=None,
1845-
sharex=False, sharey=False, **kwds):
1842+
def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
1843+
xrot=None, ylabelsize=None, yrot=None, ax=None, sharex=False,
1844+
sharey=False, figsize=None, **kwds):
18461845
"""
18471846
Draw histogram of the DataFrame's series using matplotlib / pylab.
18481847
@@ -1866,17 +1865,20 @@ def hist_frame(
18661865
ax : matplotlib axes object, default None
18671866
sharex : bool, if True, the X axis will be shared amongst all subplots.
18681867
sharey : bool, if True, the Y axis will be shared amongst all subplots.
1868+
figsize : tuple
1869+
The size in inches of the figure to create; default None
18691870
kwds : other plotting keyword arguments
18701871
To be passed to hist function
18711872
"""
18721873
if column is not None:
18731874
if not isinstance(column, (list, np.ndarray)):
18741875
column = [column]
1875-
data = data.ix[:, column]
1876+
data = data[column]
18761877

18771878
if by is not None:
18781879

1879-
axes = grouped_hist(data, by=by, ax=ax, grid=grid, **kwds)
1880+
axes = grouped_hist(data, by=by, ax=ax, grid=grid, figsize=figsize,
1881+
**kwds)
18801882

18811883
for ax in axes.ravel():
18821884
if xlabelsize is not None:
@@ -1898,11 +1900,11 @@ def hist_frame(
18981900
rows += 1
18991901
else:
19001902
cols += 1
1901-
_, axes = _subplots(nrows=rows, ncols=cols, ax=ax, squeeze=False,
1902-
sharex=sharex, sharey=sharey)
1903+
fig, axes = _subplots(nrows=rows, ncols=cols, ax=ax, squeeze=False,
1904+
sharex=sharex, sharey=sharey, figsize=figsize)
19031905

19041906
for i, col in enumerate(com._try_sort(data.columns)):
1905-
ax = axes[i / cols][i % cols]
1907+
ax = axes[i / cols, i % cols]
19061908
ax.xaxis.set_visible(True)
19071909
ax.yaxis.set_visible(True)
19081910
ax.hist(data[col].dropna().values, **kwds)
@@ -1922,13 +1924,13 @@ def hist_frame(
19221924
ax = axes[j / cols, j % cols]
19231925
ax.set_visible(False)
19241926

1925-
ax.get_figure().subplots_adjust(wspace=0.3, hspace=0.3)
1927+
fig.subplots_adjust(wspace=0.3, hspace=0.3)
19261928

19271929
return axes
19281930

19291931

19301932
def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
1931-
xrot=None, ylabelsize=None, yrot=None, **kwds):
1933+
xrot=None, ylabelsize=None, yrot=None, figsize=None, **kwds):
19321934
"""
19331935
Draw histogram of the input series using matplotlib
19341936
@@ -1948,6 +1950,8 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
19481950
If specified changes the y-axis label size
19491951
yrot : float, default None
19501952
rotation of y axis labels
1953+
figsize : tuple, default None
1954+
figure size in inches
19511955
kwds : keywords
19521956
To be passed to the actual plotting function
19531957
@@ -1958,16 +1962,22 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
19581962
"""
19591963
import matplotlib.pyplot as plt
19601964

1965+
fig = kwds.setdefault('figure', plt.figure(figsize=figsize))
1966+
19611967
if by is None:
19621968
if ax is None:
1963-
ax = plt.gca()
1969+
ax = fig.add_subplot(111)
1970+
else:
1971+
if ax.get_figure() != fig:
1972+
raise AssertionError('passed axis not bound to passed figure')
19641973
values = self.dropna().values
19651974

19661975
ax.hist(values, **kwds)
19671976
ax.grid(grid)
19681977
axes = np.array([ax])
19691978
else:
1970-
axes = grouped_hist(self, by=by, ax=ax, grid=grid, **kwds)
1979+
axes = grouped_hist(self, by=by, ax=ax, grid=grid, figsize=figsize,
1980+
**kwds)
19711981

19721982
for ax in axes.ravel():
19731983
if xlabelsize is not None:

0 commit comments

Comments
 (0)