Skip to content

Commit f3a301d

Browse files
committed
Merge pull request pandas-dev#4021 from cpcloud/hist-figure-arg-fix
BUG: allow series to use gcf-style figures
2 parents 1a67a8f + 0977398 commit f3a301d

File tree

4 files changed

+71
-41
lines changed

4 files changed

+71
-41
lines changed

doc/source/release.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,8 @@ pandas 0.12
291291
- Fixed failing tests in test_yahoo, test_google where symbols were not
292292
retrieved but were being accessed (:issue:`3982`, :issue:`3985`,
293293
:issue:`4028`, :issue:`4054`)
294-
294+
- ``Series.hist`` will now take the figure from the current environment if
295+
one is not passed
295296

296297
pandas 0.11.0
297298
=============

doc/source/v0.12.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,8 @@ Bug Fixes
434434
- Fixed failing tests in test_yahoo, test_google where symbols were not
435435
retrieved but were being accessed (:issue:`3982`, :issue:`3985`,
436436
:issue:`4028`, :issue:`4054`)
437+
- ``Series.hist`` will now take the figure from the current environment if
438+
one is not passed
437439

438440
See the :ref:`full release notes
439441
<release>` or issue tracker

pandas/tests/test_graphics.py

+55-23
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def _skip_if_no_scipy():
2626

2727

2828
class TestSeriesPlots(unittest.TestCase):
29-
3029
@classmethod
3130
def setUpClass(cls):
3231
try:
@@ -45,6 +44,10 @@ def setUp(self):
4544
self.iseries = tm.makePeriodSeries()
4645
self.iseries.name = 'iseries'
4746

47+
def tearDown(self):
48+
import matplotlib.pyplot as plt
49+
plt.close('all')
50+
4851
@slow
4952
def test_plot(self):
5053
_check_plot_works(self.ts.plot, label='foo')
@@ -178,6 +181,19 @@ def test_hist(self):
178181
_check_plot_works(self.ts.hist, figsize=(8, 10))
179182
_check_plot_works(self.ts.hist, by=self.ts.index.month)
180183

184+
import matplotlib.pyplot as plt
185+
fig, ax = plt.subplots(1, 1)
186+
_check_plot_works(self.ts.hist, ax=ax)
187+
_check_plot_works(self.ts.hist, ax=ax, figure=fig)
188+
_check_plot_works(self.ts.hist, figure=fig)
189+
plt.close('all')
190+
191+
fig, (ax1, ax2) = plt.subplots(1, 2)
192+
_check_plot_works(self.ts.hist, figure=fig, ax=ax1)
193+
_check_plot_works(self.ts.hist, figure=fig, ax=ax2)
194+
self.assertRaises(ValueError, self.ts.hist, by=self.ts.index,
195+
figure=fig)
196+
181197
def test_plot_fails_when_ax_differs_from_figure(self):
182198
from pylab import figure
183199
fig1 = figure()
@@ -196,11 +212,10 @@ def test_kde(self):
196212
@slow
197213
def test_kde_color(self):
198214
_skip_if_no_scipy()
199-
_check_plot_works(self.ts.plot, kind='kde')
200-
_check_plot_works(self.ts.plot, kind='density')
201215
ax = self.ts.plot(kind='kde', logy=True, color='r')
202-
self.assert_(ax.get_lines()[0].get_color() == 'r')
203-
self.assert_(ax.get_lines()[1].get_color() == 'r')
216+
lines = ax.get_lines()
217+
self.assertEqual(len(lines), 1)
218+
self.assertEqual(lines[0].get_color(), 'r')
204219

205220
@slow
206221
def test_autocorrelation_plot(self):
@@ -228,7 +243,6 @@ def test_invalid_plot_data(self):
228243

229244
@slow
230245
def test_valid_object_plot(self):
231-
from pandas.io.common import PerformanceWarning
232246
s = Series(range(10), dtype=object)
233247
kinds = 'line', 'bar', 'barh', 'kde', 'density'
234248

@@ -262,6 +276,10 @@ def setUpClass(cls):
262276
except ImportError:
263277
raise nose.SkipTest
264278

279+
def tearDown(self):
280+
import matplotlib.pyplot as plt
281+
plt.close('all')
282+
265283
@slow
266284
def test_plot(self):
267285
df = tm.makeTimeDataFrame()
@@ -804,19 +822,18 @@ def test_invalid_kind(self):
804822

805823

806824
class TestDataFrameGroupByPlots(unittest.TestCase):
807-
808825
@classmethod
809826
def setUpClass(cls):
810-
# import sys
811-
# if 'IPython' in sys.modules:
812-
# raise nose.SkipTest
813-
814827
try:
815828
import matplotlib as mpl
816829
mpl.use('Agg', warn=False)
817830
except ImportError:
818831
raise nose.SkipTest
819832

833+
def tearDown(self):
834+
import matplotlib.pyplot as plt
835+
plt.close('all')
836+
820837
@slow
821838
def test_boxplot(self):
822839
df = DataFrame(np.random.rand(10, 2), columns=['Col1', 'Col2'])
@@ -906,12 +923,6 @@ def test_grouped_hist(self):
906923
by=df.C, foo='bar')
907924

908925
def test_option_mpl_style(self):
909-
# just a sanity check
910-
try:
911-
import matplotlib
912-
except:
913-
raise nose.SkipTest
914-
915926
set_option('display.mpl_style', 'default')
916927
set_option('display.mpl_style', None)
917928
set_option('display.mpl_style', False)
@@ -925,22 +936,43 @@ def test_invalid_colormap(self):
925936

926937
self.assertRaises(ValueError, df.plot, colormap='invalid_colormap')
927938

939+
940+
def assert_is_valid_plot_return_object(objs):
941+
import matplotlib.pyplot as plt
942+
if isinstance(objs, np.ndarray):
943+
for el in objs.flat:
944+
assert isinstance(el, plt.Axes), ('one of \'objs\' is not a '
945+
'matplotlib Axes instance, '
946+
'type encountered {0!r}'
947+
''.format(el.__class__.__name__))
948+
else:
949+
assert isinstance(objs, (plt.Artist, tuple, dict)), \
950+
('objs is neither an ndarray of Artist instances nor a '
951+
'single Artist instance, tuple, or dict, "objs" is a {0!r} '
952+
''.format(objs.__class__.__name__))
953+
954+
928955
def _check_plot_works(f, *args, **kwargs):
929956
import matplotlib.pyplot as plt
930957

931-
fig = plt.gcf()
958+
try:
959+
fig = kwargs['figure']
960+
except KeyError:
961+
fig = plt.gcf()
932962
plt.clf()
933-
ax = fig.add_subplot(211)
963+
ax = kwargs.get('ax', fig.add_subplot(211))
934964
ret = f(*args, **kwargs)
935-
assert ret is not None # do something more intelligent
936965

937-
ax = fig.add_subplot(212)
966+
assert ret is not None
967+
assert_is_valid_plot_return_object(ret)
968+
938969
try:
939-
kwargs['ax'] = ax
970+
kwargs['ax'] = fig.add_subplot(212)
940971
ret = f(*args, **kwargs)
941-
assert(ret is not None) # do something more intelligent
942972
except Exception:
943973
pass
974+
else:
975+
assert_is_valid_plot_return_object(ret)
944976

945977
with ensure_clean() as path:
946978
plt.savefig(path)

pandas/tools/plotting.py

+12-17
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,6 @@ def radviz(frame, class_column, ax=None, colormap=None, **kwds):
339339
"""
340340
import matplotlib.pyplot as plt
341341
import matplotlib.patches as patches
342-
import matplotlib.text as text
343-
import random
344342

345343
def normalize(series):
346344
a = min(series)
@@ -378,10 +376,8 @@ def normalize(series):
378376
to_plot[class_name][1].append(y[1])
379377

380378
for i, class_ in enumerate(classes):
381-
line = ax.scatter(to_plot[class_][0],
382-
to_plot[class_][1],
383-
color=colors[i],
384-
label=com.pprint_thing(class_), **kwds)
379+
ax.scatter(to_plot[class_][0], to_plot[class_][1], color=colors[i],
380+
label=com.pprint_thing(class_), **kwds)
385381
ax.legend()
386382

387383
ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none'))
@@ -429,7 +425,6 @@ def andrews_curves(data, class_column, ax=None, samples=200, colormap=None,
429425
"""
430426
from math import sqrt, pi, sin, cos
431427
import matplotlib.pyplot as plt
432-
import random
433428

434429
def function(amplitudes):
435430
def f(x):
@@ -445,9 +440,7 @@ def f(x):
445440
return result
446441
return f
447442

448-
449443
n = len(data)
450-
classes = set(data[class_column])
451444
class_col = data[class_column]
452445
columns = [data[col] for col in data.columns if (col != class_column)]
453446
x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)]
@@ -492,7 +485,6 @@ def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
492485
fig: matplotlib figure
493486
"""
494487
import random
495-
import matplotlib
496488
import matplotlib.pyplot as plt
497489

498490
# random.sample(ndarray, int) fails on python 3.3, sigh
@@ -576,7 +568,6 @@ def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None,
576568
>>> plt.show()
577569
"""
578570
import matplotlib.pyplot as plt
579-
import random
580571

581572

582573
n = len(data)
@@ -1240,7 +1231,6 @@ def _use_dynamic_x(self):
12401231
return (freq is not None) and self._is_dynamic_freq(freq)
12411232

12421233
def _make_plot(self):
1243-
import pandas.tseries.plotting as tsplot
12441234
# this is slightly deceptive
12451235
if not self.x_compat and self.use_index and self._use_dynamic_x():
12461236
data = self._maybe_convert_index(self.data)
@@ -2021,20 +2011,26 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
20212011
"""
20222012
import matplotlib.pyplot as plt
20232013

2024-
fig = kwds.setdefault('figure', plt.figure(figsize=figsize))
2014+
fig = kwds.get('figure', plt.gcf()
2015+
if plt.get_fignums() else plt.figure(figsize=figsize))
2016+
if figsize is not None and tuple(figsize) != tuple(fig.get_size_inches()):
2017+
fig.set_size_inches(*figsize, forward=True)
20252018

20262019
if by is None:
20272020
if ax is None:
20282021
ax = fig.add_subplot(111)
2029-
else:
2030-
if ax.get_figure() != fig:
2031-
raise AssertionError('passed axis not bound to passed figure')
2022+
if ax.get_figure() != fig:
2023+
raise AssertionError('passed axis not bound to passed figure')
20322024
values = self.dropna().values
20332025

20342026
ax.hist(values, **kwds)
20352027
ax.grid(grid)
20362028
axes = np.array([ax])
20372029
else:
2030+
if 'figure' in kwds:
2031+
raise ValueError("Cannot pass 'figure' when using the "
2032+
"'by' argument, since a new 'Figure' instance "
2033+
"will be created")
20382034
axes = grouped_hist(self, by=by, ax=ax, grid=grid, figsize=figsize,
20392035
**kwds)
20402036

@@ -2384,7 +2380,6 @@ def on_right(i):
23842380

23852381

23862382
def _get_xlim(lines):
2387-
import pandas.tseries.converter as conv
23882383
left, right = np.inf, -np.inf
23892384
for l in lines:
23902385
x = l.get_xdata()

0 commit comments

Comments
 (0)