From 2336fb343a1c6d189b084e585683421da1b09968 Mon Sep 17 00:00:00 2001 From: Chris Whelan Date: Mon, 19 Jan 2015 14:26:02 -0800 Subject: [PATCH] Fix plotting memory leak and add regression test --- doc/source/whatsnew/v0.16.1.txt | 2 +- pandas/tests/test_graphics.py | 29 ++++ pandas/tools/plotting.py | 272 ++++++++++++++++++-------------- 3 files changed, 185 insertions(+), 118 deletions(-) diff --git a/doc/source/whatsnew/v0.16.1.txt b/doc/source/whatsnew/v0.16.1.txt index 7166801b3fbf0..d36f094ae00cd 100755 --- a/doc/source/whatsnew/v0.16.1.txt +++ b/doc/source/whatsnew/v0.16.1.txt @@ -141,7 +141,6 @@ Bug Fixes - - Bug in unequal comparisons between categorical data and a scalar, which was not in the categories (e.g. ``Series(Categorical(list("abc"), ordered=True)) > "d"``. This returned ``False`` for all elements, but now raises a ``TypeError``. Equality comparisons also now return ``False`` for ``==`` and ``True`` for ``!=``. (:issue:`9848`) - Bug in DataFrame ``__setitem__`` when right hand side is a dictionary (:issue:`9874`) - Bug in ``where`` when dtype is ``datetime64/timedelta64``, but dtype of other is not (:issue:`9804`) @@ -164,3 +163,4 @@ Bug Fixes - Fixed latex output for multi-indexed dataframes (:issue:`9778`) - Bug causing an exception when setting an empty range using ``DataFrame.loc`` (:issue:`9596`) +- Fixed memory leak in ``AreaPlot`` and ``LinePlot`` that prevented calls to ``plt.close()`` from having any effect. (:issue:`9003`) diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 7ec57c0304530..638d4bae8e7d5 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -28,6 +28,8 @@ from numpy.testing import assert_array_equal, assert_allclose from numpy.testing.decorators import slow import pandas.tools.plotting as plotting +import weakref +import gc def _skip_if_mpl_14_or_dev_boxplot(): @@ -3390,6 +3392,33 @@ def test_sharey_and_ax(self): "y label is invisible but shouldn't") + def test_memory_leak(self): + """ Check that every plot type gets properly collected. """ + import matplotlib.pyplot as plt + results = {} + for kind in plotting._plot_klass.keys(): + args = {} + if kind in ['hexbin', 'scatter', 'pie']: + df = self.hexbin_df + args = {'x': 'A', 'y': 'B'} + elif kind == 'area': + df = self.tdf.abs() + else: + df = self.tdf + + # Use a weakref so we can see if the object gets collected without + # also preventing it from being collected + results[kind] = weakref.proxy(df.plot(kind=kind, **args)) + + # have matplotlib delete all the figures + plt.close('all') + # force a garbage collection + gc.collect() + for key in results: + # check that every plot was collected + with tm.assertRaises(ReferenceError): + # need to actually access something to get an error + results[key].lines @tm.mplskip class TestDataFrameGroupByPlots(TestPlotBase): diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 513f165af4686..f83e5cbd17368 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -738,6 +738,135 @@ def r(h): ax.grid() return ax +def _mplplot_plotf(errorbar=False): + import matplotlib.pyplot as plt + def plotf(ax, x, y, style=None, **kwds): + mask = com.isnull(y) + if mask.any(): + y = np.ma.array(y) + y = np.ma.masked_where(mask, y) + + if errorbar: + return plt.Axes.errorbar(ax, x, y, **kwds) + else: + # prevent style kwarg from going to errorbar, where it is unsupported + if style is not None: + args = (ax, x, y, style) + else: + args = (ax, x, y) + return plt.Axes.plot(*args, **kwds) + + return plotf + + +def _lineplot_plotf(f, stacked, subplots): + def plotf(ax, x, y, style=None, column_num=None, **kwds): + # column_num is used to get the target column from protf in line and area plots + if not hasattr(ax, '_pos_prior') or column_num == 0: + LinePlot._initialize_prior(ax, len(y)) + y_values = LinePlot._get_stacked_values(ax, y, kwds['label'], stacked) + lines = f(ax, x, y_values, style=style, **kwds) + LinePlot._update_prior(ax, y, stacked, subplots) + return lines + + return plotf + + +def _areaplot_plotf(f, stacked, subplots): + import matplotlib.pyplot as plt + def plotf(ax, x, y, style=None, column_num=None, **kwds): + if not hasattr(ax, '_pos_prior') or column_num == 0: + LinePlot._initialize_prior(ax, len(y)) + y_values = LinePlot._get_stacked_values(ax, y, kwds['label'], stacked) + lines = f(ax, x, y_values, style=style, **kwds) + + # get data from the line to get coordinates for fill_between + xdata, y_values = lines[0].get_data(orig=False) + + if (y >= 0).all(): + start = ax._pos_prior + elif (y <= 0).all(): + start = ax._neg_prior + else: + start = np.zeros(len(y)) + + if not 'color' in kwds: + kwds['color'] = lines[0].get_color() + + plt.Axes.fill_between(ax, xdata, start, y_values, **kwds) + LinePlot._update_prior(ax, y, stacked, subplots) + return lines + + return plotf + + +def _histplot_plotf(bins, bottom, stacked, subplots): + import matplotlib.pyplot as plt + def plotf(ax, y, style=None, column_num=None, **kwds): + if not hasattr(ax, '_pos_prior') or column_num == 0: + LinePlot._initialize_prior(ax, len(bins) - 1) + y = y[~com.isnull(y)] + new_bottom = ax._pos_prior + bottom + # ignore style + n, new_bins, patches = plt.Axes.hist(ax, y, bins=bins, + bottom=new_bottom, **kwds) + LinePlot._update_prior(ax, n, stacked, subplots) + return patches + + return plotf + + +def _boxplot_plotf(return_type): + def plotf(ax, y, column_num=None, **kwds): + if y.ndim == 2: + y = [remove_na(v) for v in y] + # Boxplot fails with empty arrays, so need to add a NaN + # if any cols are empty + # GH 8181 + y = [v if v.size > 0 else np.array([np.nan]) for v in y] + else: + y = remove_na(y) + bp = ax.boxplot(y, **kwds) + + if return_type == 'dict': + return bp, bp + elif return_type == 'both': + return BoxPlot.BP(ax=ax, lines=bp), bp + else: + return ax, bp + + return plotf + + +def _kdeplot_plotf(f, bw_method, ind): + from scipy.stats import gaussian_kde + from scipy import __version__ as spv + + def plotf(ax, y, style=None, column_num=None, **kwds): + y = remove_na(y) + if LooseVersion(spv) >= '0.11.0': + gkde = gaussian_kde(y, bw_method=bw_method) + else: + gkde = gaussian_kde(y) + if bw_method is not None: + msg = ('bw_method was added in Scipy 0.11.0.' + + ' Scipy version in use is %s.' % spv) + warnings.warn(msg) + + if ind is None: + sample_range = max(y) - min(y) + ind_local = np.linspace(min(y) - 0.5 * sample_range, + max(y) + 0.5 * sample_range, 1000) + else: + ind_local = ind + + y = gkde.evaluate(ind_local) + lines = f(ax, ind_local, y, style=style, **kwds) + return lines + + return plotf + + class MPLPlot(object): """ @@ -1194,28 +1323,15 @@ def _is_datetype(self): index.inferred_type in ('datetime', 'date', 'datetime64', 'time')) + def _plot_errors(self): + return any(e is not None for e in self.errors.values()) + def _get_plot_function(self): ''' Returns the matplotlib plotting function (plot or errorbar) based on the presence of errorbar keywords. ''' - errorbar = any(e is not None for e in self.errors.values()) - def plotf(ax, x, y, style=None, **kwds): - mask = com.isnull(y) - if mask.any(): - y = np.ma.array(y) - y = np.ma.masked_where(mask, y) - - if errorbar: - return self.plt.Axes.errorbar(ax, x, y, **kwds) - else: - # prevent style kwarg from going to errorbar, where it is unsupported - if style is not None: - args = (ax, x, y, style) - else: - args = (ax, x, y) - return self.plt.Axes.plot(*args, **kwds) - return plotf + return _mplplot_plotf(self._plot_errors()) def _get_index_name(self): if isinstance(self.data.index, MultiIndex): @@ -1594,7 +1710,6 @@ def _is_ts_plot(self): return not self.x_compat and self.use_index and self._use_dynamic_x() def _make_plot(self): - self._initialize_prior(len(self.data)) if self._is_ts_plot(): data = self._maybe_convert_index(self.data) @@ -1626,12 +1741,13 @@ def _make_plot(self): left, right = _get_xlim(lines) ax.set_xlim(left, right) - def _get_stacked_values(self, y, label): - if self.stacked: + @classmethod + def _get_stacked_values(cls, ax, y, label, stacked): + if stacked: if (y >= 0).all(): - return self._pos_prior + y + return ax._pos_prior + y elif (y <= 0).all(): - return self._neg_prior + y + return ax._neg_prior + y else: raise ValueError('When stacked is True, each column must be either all positive or negative.' '{0} contains both positive and negative values'.format(label)) @@ -1640,15 +1756,8 @@ def _get_stacked_values(self, y, label): def _get_plot_function(self): f = MPLPlot._get_plot_function(self) - def plotf(ax, x, y, style=None, column_num=None, **kwds): - # column_num is used to get the target column from protf in line and area plots - if column_num == 0: - self._initialize_prior(len(self.data)) - y_values = self._get_stacked_values(y, kwds['label']) - lines = f(ax, x, y_values, style=style, **kwds) - self._update_prior(y) - return lines - return plotf + + return _lineplot_plotf(f, self.stacked, self.subplots) def _get_ts_plot_function(self): from pandas.tseries.plotting import tsplot @@ -1660,19 +1769,21 @@ def _plot(ax, x, data, style=None, **kwds): return lines return _plot - def _initialize_prior(self, n): - self._pos_prior = np.zeros(n) - self._neg_prior = np.zeros(n) + @classmethod + def _initialize_prior(cls, ax, n): + ax._pos_prior = np.zeros(n) + ax._neg_prior = np.zeros(n) - def _update_prior(self, y): - if self.stacked and not self.subplots: + @classmethod + def _update_prior(cls, ax, y, stacked, subplots): + if stacked and not subplots: # tsplot resample may changedata length - if len(self._pos_prior) != len(y): - self._initialize_prior(len(y)) + if len(ax._pos_prior) != len(y): + cls._initialize_prior(ax, len(y)) if (y >= 0).all(): - self._pos_prior += y + ax._pos_prior += y elif (y <= 0).all(): - self._neg_prior += y + ax._neg_prior += y def _maybe_convert_index(self, data): # tsplot converts automatically, but don't want to convert index @@ -1736,28 +1847,8 @@ def _get_plot_function(self): raise ValueError("Log-y scales are not supported in area plot") else: f = MPLPlot._get_plot_function(self) - def plotf(ax, x, y, style=None, column_num=None, **kwds): - if column_num == 0: - self._initialize_prior(len(self.data)) - y_values = self._get_stacked_values(y, kwds['label']) - lines = f(ax, x, y_values, style=style, **kwds) - - # get data from the line to get coordinates for fill_between - xdata, y_values = lines[0].get_data(orig=False) - - if (y >= 0).all(): - start = self._pos_prior - elif (y <= 0).all(): - start = self._neg_prior - else: - start = np.zeros(len(y)) - if not 'color' in kwds: - kwds['color'] = lines[0].get_color() - - self.plt.Axes.fill_between(ax, xdata, start, y_values, **kwds) - self._update_prior(y) - return lines + return _areaplot_plotf(f, self.stacked, self.subplots) return plotf @@ -1943,17 +2034,7 @@ def _args_adjust(self): self.bottom = np.array(self.bottom) def _get_plot_function(self): - def plotf(ax, y, style=None, column_num=None, **kwds): - if column_num == 0: - self._initialize_prior(len(self.bins) - 1) - y = y[~com.isnull(y)] - bottom = self._pos_prior + self.bottom - # ignore style - n, bins, patches = self.plt.Axes.hist(ax, y, bins=self.bins, - bottom=bottom, **kwds) - self._update_prior(n) - return patches - return plotf + return _histplot_plotf(self.bins, self.bottom, self.stacked, self.subplots) def _make_plot(self): plotf = self._get_plot_function() @@ -2000,35 +2081,9 @@ def __init__(self, data, bw_method=None, ind=None, **kwargs): def _args_adjust(self): pass - def _get_ind(self, y): - if self.ind is None: - sample_range = max(y) - min(y) - ind = np.linspace(min(y) - 0.5 * sample_range, - max(y) + 0.5 * sample_range, 1000) - else: - ind = self.ind - return ind - def _get_plot_function(self): - from scipy.stats import gaussian_kde - from scipy import __version__ as spv f = MPLPlot._get_plot_function(self) - def plotf(ax, y, style=None, column_num=None, **kwds): - y = remove_na(y) - if LooseVersion(spv) >= '0.11.0': - gkde = gaussian_kde(y, bw_method=self.bw_method) - else: - gkde = gaussian_kde(y) - if self.bw_method is not None: - msg = ('bw_method was added in Scipy 0.11.0.' + - ' Scipy version in use is %s.' % spv) - warnings.warn(msg) - - ind = self._get_ind(y) - y = gkde.evaluate(ind) - lines = f(ax, ind, y, style=style, **kwds) - return lines - return plotf + return _kdeplot_plotf(f, self.bw_method, self.ind) def _post_plot_logic(self): for ax in self.axes: @@ -2123,24 +2178,7 @@ def _args_adjust(self): self.sharey = False def _get_plot_function(self): - def plotf(ax, y, column_num=None, **kwds): - if y.ndim == 2: - y = [remove_na(v) for v in y] - # Boxplot fails with empty arrays, so need to add a NaN - # if any cols are empty - # GH 8181 - y = [v if v.size > 0 else np.array([np.nan]) for v in y] - else: - y = remove_na(y) - bp = ax.boxplot(y, **kwds) - - if self.return_type == 'dict': - return bp, bp - elif self.return_type == 'both': - return self.BP(ax=ax, lines=bp), bp - else: - return ax, bp - return plotf + return _boxplot_plotf(self.return_type) def _validate_color_args(self): if 'color' in self.kwds: