diff --git a/doc/source/release.rst b/doc/source/release.rst index 114b5d749c85c..1edb44502221c 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -291,7 +291,8 @@ pandas 0.12 - Fixed failing tests in test_yahoo, test_google where symbols were not retrieved but were being accessed (:issue:`3982`, :issue:`3985`, :issue:`4028`, :issue:`4054`) - + - ``Series.hist`` will now take the figure from the current environment if + one is not passed pandas 0.11.0 ============= diff --git a/doc/source/v0.12.0.txt b/doc/source/v0.12.0.txt index 203982a4e8c93..0d2251bf225d9 100644 --- a/doc/source/v0.12.0.txt +++ b/doc/source/v0.12.0.txt @@ -434,6 +434,8 @@ Bug Fixes - Fixed failing tests in test_yahoo, test_google where symbols were not retrieved but were being accessed (:issue:`3982`, :issue:`3985`, :issue:`4028`, :issue:`4054`) + - ``Series.hist`` will now take the figure from the current environment if + one is not passed See the :ref:`full release notes ` or issue tracker diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index d094e8b99d9cb..fe793275627e0 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -26,7 +26,6 @@ def _skip_if_no_scipy(): class TestSeriesPlots(unittest.TestCase): - @classmethod def setUpClass(cls): try: @@ -45,6 +44,10 @@ def setUp(self): self.iseries = tm.makePeriodSeries() self.iseries.name = 'iseries' + def tearDown(self): + import matplotlib.pyplot as plt + plt.close('all') + @slow def test_plot(self): _check_plot_works(self.ts.plot, label='foo') @@ -178,6 +181,19 @@ def test_hist(self): _check_plot_works(self.ts.hist, figsize=(8, 10)) _check_plot_works(self.ts.hist, by=self.ts.index.month) + import matplotlib.pyplot as plt + fig, ax = plt.subplots(1, 1) + _check_plot_works(self.ts.hist, ax=ax) + _check_plot_works(self.ts.hist, ax=ax, figure=fig) + _check_plot_works(self.ts.hist, figure=fig) + plt.close('all') + + fig, (ax1, ax2) = plt.subplots(1, 2) + _check_plot_works(self.ts.hist, figure=fig, ax=ax1) + _check_plot_works(self.ts.hist, figure=fig, ax=ax2) + self.assertRaises(ValueError, self.ts.hist, by=self.ts.index, + figure=fig) + def test_plot_fails_when_ax_differs_from_figure(self): from pylab import figure fig1 = figure() @@ -196,11 +212,10 @@ def test_kde(self): @slow def test_kde_color(self): _skip_if_no_scipy() - _check_plot_works(self.ts.plot, kind='kde') - _check_plot_works(self.ts.plot, kind='density') ax = self.ts.plot(kind='kde', logy=True, color='r') - self.assert_(ax.get_lines()[0].get_color() == 'r') - self.assert_(ax.get_lines()[1].get_color() == 'r') + lines = ax.get_lines() + self.assertEqual(len(lines), 1) + self.assertEqual(lines[0].get_color(), 'r') @slow def test_autocorrelation_plot(self): @@ -228,7 +243,6 @@ def test_invalid_plot_data(self): @slow def test_valid_object_plot(self): - from pandas.io.common import PerformanceWarning s = Series(range(10), dtype=object) kinds = 'line', 'bar', 'barh', 'kde', 'density' @@ -262,6 +276,10 @@ def setUpClass(cls): except ImportError: raise nose.SkipTest + def tearDown(self): + import matplotlib.pyplot as plt + plt.close('all') + @slow def test_plot(self): df = tm.makeTimeDataFrame() @@ -804,19 +822,18 @@ def test_invalid_kind(self): class TestDataFrameGroupByPlots(unittest.TestCase): - @classmethod def setUpClass(cls): - # import sys - # if 'IPython' in sys.modules: - # raise nose.SkipTest - try: import matplotlib as mpl mpl.use('Agg', warn=False) except ImportError: raise nose.SkipTest + def tearDown(self): + import matplotlib.pyplot as plt + plt.close('all') + @slow def test_boxplot(self): df = DataFrame(np.random.rand(10, 2), columns=['Col1', 'Col2']) @@ -906,12 +923,6 @@ def test_grouped_hist(self): by=df.C, foo='bar') def test_option_mpl_style(self): - # just a sanity check - try: - import matplotlib - except: - raise nose.SkipTest - set_option('display.mpl_style', 'default') set_option('display.mpl_style', None) set_option('display.mpl_style', False) @@ -925,22 +936,43 @@ def test_invalid_colormap(self): self.assertRaises(ValueError, df.plot, colormap='invalid_colormap') + +def assert_is_valid_plot_return_object(objs): + import matplotlib.pyplot as plt + if isinstance(objs, np.ndarray): + for el in objs.flat: + assert isinstance(el, plt.Axes), ('one of \'objs\' is not a ' + 'matplotlib Axes instance, ' + 'type encountered {0!r}' + ''.format(el.__class__.__name__)) + else: + assert isinstance(objs, (plt.Artist, tuple, dict)), \ + ('objs is neither an ndarray of Artist instances nor a ' + 'single Artist instance, tuple, or dict, "objs" is a {0!r} ' + ''.format(objs.__class__.__name__)) + + def _check_plot_works(f, *args, **kwargs): import matplotlib.pyplot as plt - fig = plt.gcf() + try: + fig = kwargs['figure'] + except KeyError: + fig = plt.gcf() plt.clf() - ax = fig.add_subplot(211) + ax = kwargs.get('ax', fig.add_subplot(211)) ret = f(*args, **kwargs) - assert ret is not None # do something more intelligent - ax = fig.add_subplot(212) + assert ret is not None + assert_is_valid_plot_return_object(ret) + try: - kwargs['ax'] = ax + kwargs['ax'] = fig.add_subplot(212) ret = f(*args, **kwargs) - assert(ret is not None) # do something more intelligent except Exception: pass + else: + assert_is_valid_plot_return_object(ret) with ensure_clean() as path: plt.savefig(path) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 8abe9df5ddd56..2ed9d2f607ea9 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -339,8 +339,6 @@ def radviz(frame, class_column, ax=None, colormap=None, **kwds): """ import matplotlib.pyplot as plt import matplotlib.patches as patches - import matplotlib.text as text - import random def normalize(series): a = min(series) @@ -378,10 +376,8 @@ def normalize(series): to_plot[class_name][1].append(y[1]) for i, class_ in enumerate(classes): - line = ax.scatter(to_plot[class_][0], - to_plot[class_][1], - color=colors[i], - label=com.pprint_thing(class_), **kwds) + ax.scatter(to_plot[class_][0], to_plot[class_][1], color=colors[i], + label=com.pprint_thing(class_), **kwds) ax.legend() ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none')) @@ -429,7 +425,6 @@ def andrews_curves(data, class_column, ax=None, samples=200, colormap=None, """ from math import sqrt, pi, sin, cos import matplotlib.pyplot as plt - import random def function(amplitudes): def f(x): @@ -445,9 +440,7 @@ def f(x): return result return f - n = len(data) - classes = set(data[class_column]) class_col = data[class_column] columns = [data[col] for col in data.columns if (col != class_column)] x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)] @@ -492,7 +485,6 @@ def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds): fig: matplotlib figure """ import random - import matplotlib import matplotlib.pyplot as plt # random.sample(ndarray, int) fails on python 3.3, sigh @@ -576,7 +568,6 @@ def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None, >>> plt.show() """ import matplotlib.pyplot as plt - import random n = len(data) @@ -1240,7 +1231,6 @@ def _use_dynamic_x(self): return (freq is not None) and self._is_dynamic_freq(freq) def _make_plot(self): - import pandas.tseries.plotting as tsplot # this is slightly deceptive if not self.x_compat and self.use_index and self._use_dynamic_x(): data = self._maybe_convert_index(self.data) @@ -2021,20 +2011,26 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None, """ import matplotlib.pyplot as plt - fig = kwds.setdefault('figure', plt.figure(figsize=figsize)) + fig = kwds.get('figure', plt.gcf() + if plt.get_fignums() else plt.figure(figsize=figsize)) + if figsize is not None and tuple(figsize) != tuple(fig.get_size_inches()): + fig.set_size_inches(*figsize, forward=True) if by is None: if ax is None: ax = fig.add_subplot(111) - else: - if ax.get_figure() != fig: - raise AssertionError('passed axis not bound to passed figure') + if ax.get_figure() != fig: + raise AssertionError('passed axis not bound to passed figure') values = self.dropna().values ax.hist(values, **kwds) ax.grid(grid) axes = np.array([ax]) else: + if 'figure' in kwds: + raise ValueError("Cannot pass 'figure' when using the " + "'by' argument, since a new 'Figure' instance " + "will be created") axes = grouped_hist(self, by=by, ax=ax, grid=grid, figsize=figsize, **kwds) @@ -2384,7 +2380,6 @@ def on_right(i): def _get_xlim(lines): - import pandas.tseries.converter as conv left, right = np.inf, -np.inf for l in lines: x = l.get_xdata()