Skip to content

Commit 795d566

Browse files
committed
Fix plotting memory leak and add regression test
1 parent 35d0893 commit 795d566

File tree

2 files changed

+184
-117
lines changed

2 files changed

+184
-117
lines changed

pandas/tests/test_graphics.py

+29
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from numpy.testing import assert_array_equal, assert_allclose
2929
from numpy.testing.decorators import slow
3030
import pandas.tools.plotting as plotting
31+
import weakref
32+
import gc
3133

3234

3335
def _skip_if_mpl_14_or_dev_boxplot():
@@ -3138,6 +3140,33 @@ def _check_errorbar_color(containers, expected, has_err='has_xerr'):
31383140
self._check_has_errorbars(ax, xerr=0, yerr=1)
31393141
_check_errorbar_color(ax.containers, 'green', has_err='has_yerr')
31403142

3143+
def test_memory_leak(self):
3144+
""" Check that every plot type gets properly collected. """
3145+
import matplotlib.pyplot as plt
3146+
results = {}
3147+
for kind in plotting._plot_klass.keys():
3148+
args = {}
3149+
if kind in ['hexbin', 'scatter', 'pie']:
3150+
df = self.hexbin_df
3151+
args = {'x': 'A', 'y': 'B'}
3152+
elif kind == 'area':
3153+
df = self.tdf.abs()
3154+
else:
3155+
df = self.tdf
3156+
3157+
# Use a weakref so we can see if the object gets collected without
3158+
# also preventing it from being collected
3159+
results[kind] = weakref.proxy(df.plot(kind=kind, **args))
3160+
3161+
# have matplotlib delete all the figures
3162+
plt.close('all')
3163+
# force a garbage collection
3164+
gc.collect()
3165+
for key in results:
3166+
# check that every plot was collected
3167+
with tm.assertRaises(ReferenceError):
3168+
# need to actually access something to get an error
3169+
results[key].lines
31413170

31423171
@tm.mplskip
31433172
class TestDataFrameGroupByPlots(TestPlotBase):

pandas/tools/plotting.py

+155-117
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,135 @@ def r(h):
750750
ax.grid()
751751
return ax
752752

753+
def _mplplot_plotf(errorbar=False):
754+
import matplotlib.pyplot as plt
755+
def plotf(ax, x, y, style=None, **kwds):
756+
mask = com.isnull(y)
757+
if mask.any():
758+
y = np.ma.array(y)
759+
y = np.ma.masked_where(mask, y)
760+
761+
if errorbar:
762+
return plt.Axes.errorbar(ax, x, y, **kwds)
763+
else:
764+
# prevent style kwarg from going to errorbar, where it is unsupported
765+
if style is not None:
766+
args = (ax, x, y, style)
767+
else:
768+
args = (ax, x, y)
769+
return plt.Axes.plot(*args, **kwds)
770+
771+
return plotf
772+
773+
774+
def _lineplot_plotf(f, stacked, subplots):
775+
def plotf(ax, x, y, style=None, column_num=None, **kwds):
776+
# column_num is used to get the target column from protf in line and area plots
777+
if not hasattr(ax, '_pos_prior') or column_num == 0:
778+
LinePlot._initialize_prior(ax, len(y))
779+
y_values = LinePlot._get_stacked_values(ax, y, kwds['label'], stacked)
780+
lines = f(ax, x, y_values, style=style, **kwds)
781+
LinePlot._update_prior(ax, y, stacked, subplots)
782+
return lines
783+
784+
return plotf
785+
786+
787+
def _areaplot_plotf(f, stacked, subplots):
788+
import matplotlib.pyplot as plt
789+
def plotf(ax, x, y, style=None, column_num=None, **kwds):
790+
if not hasattr(ax, '_pos_prior') or column_num == 0:
791+
LinePlot._initialize_prior(ax, len(y))
792+
y_values = LinePlot._get_stacked_values(ax, y, kwds['label'], stacked)
793+
lines = f(ax, x, y_values, style=style, **kwds)
794+
795+
# get data from the line to get coordinates for fill_between
796+
xdata, y_values = lines[0].get_data(orig=False)
797+
798+
if (y >= 0).all():
799+
start = ax._pos_prior
800+
elif (y <= 0).all():
801+
start = ax._neg_prior
802+
else:
803+
start = np.zeros(len(y))
804+
805+
if not 'color' in kwds:
806+
kwds['color'] = lines[0].get_color()
807+
808+
plt.Axes.fill_between(ax, xdata, start, y_values, **kwds)
809+
LinePlot._update_prior(ax, y, stacked, subplots)
810+
return lines
811+
812+
return plotf
813+
814+
815+
def _histplot_plotf(bins, bottom, stacked, subplots):
816+
import matplotlib.pyplot as plt
817+
def plotf(ax, y, style=None, column_num=None, **kwds):
818+
if not hasattr(ax, '_pos_prior') or column_num == 0:
819+
LinePlot._initialize_prior(ax, len(bins) - 1)
820+
y = y[~com.isnull(y)]
821+
new_bottom = ax._pos_prior + bottom
822+
# ignore style
823+
n, new_bins, patches = plt.Axes.hist(ax, y, bins=bins,
824+
bottom=new_bottom, **kwds)
825+
LinePlot._update_prior(ax, n, stacked, subplots)
826+
return patches
827+
828+
return plotf
829+
830+
831+
def _boxplot_plotf(return_type):
832+
def plotf(ax, y, column_num=None, **kwds):
833+
if y.ndim == 2:
834+
y = [remove_na(v) for v in y]
835+
# Boxplot fails with empty arrays, so need to add a NaN
836+
# if any cols are empty
837+
# GH 8181
838+
y = [v if v.size > 0 else np.array([np.nan]) for v in y]
839+
else:
840+
y = remove_na(y)
841+
bp = ax.boxplot(y, **kwds)
842+
843+
if return_type == 'dict':
844+
return bp, bp
845+
elif return_type == 'both':
846+
return BoxPlot.BP(ax=ax, lines=bp), bp
847+
else:
848+
return ax, bp
849+
850+
return plotf
851+
852+
853+
def _kdeplot_plotf(f, bw_method, ind):
854+
from scipy.stats import gaussian_kde
855+
from scipy import __version__ as spv
856+
857+
def plotf(ax, y, style=None, column_num=None, **kwds):
858+
y = remove_na(y)
859+
if LooseVersion(spv) >= '0.11.0':
860+
gkde = gaussian_kde(y, bw_method=bw_method)
861+
else:
862+
gkde = gaussian_kde(y)
863+
if bw_method is not None:
864+
msg = ('bw_method was added in Scipy 0.11.0.' +
865+
' Scipy version in use is %s.' % spv)
866+
warnings.warn(msg)
867+
868+
if ind is None:
869+
sample_range = max(y) - min(y)
870+
ind_local = np.linspace(min(y) - 0.5 * sample_range,
871+
max(y) + 0.5 * sample_range, 1000)
872+
else:
873+
ind_local = ind
874+
875+
y = gkde.evaluate(ind_local)
876+
lines = f(ax, ind_local, y, style=style, **kwds)
877+
return lines
878+
879+
return plotf
880+
881+
753882

754883
class MPLPlot(object):
755884
"""
@@ -1182,28 +1311,15 @@ def _is_datetype(self):
11821311
index.inferred_type in ('datetime', 'date', 'datetime64',
11831312
'time'))
11841313

1314+
def _plot_errors(self):
1315+
return any(e is not None for e in self.errors.values())
1316+
11851317
def _get_plot_function(self):
11861318
'''
11871319
Returns the matplotlib plotting function (plot or errorbar) based on
11881320
the presence of errorbar keywords.
11891321
'''
1190-
errorbar = any(e is not None for e in self.errors.values())
1191-
def plotf(ax, x, y, style=None, **kwds):
1192-
mask = com.isnull(y)
1193-
if mask.any():
1194-
y = np.ma.array(y)
1195-
y = np.ma.masked_where(mask, y)
1196-
1197-
if errorbar:
1198-
return self.plt.Axes.errorbar(ax, x, y, **kwds)
1199-
else:
1200-
# prevent style kwarg from going to errorbar, where it is unsupported
1201-
if style is not None:
1202-
args = (ax, x, y, style)
1203-
else:
1204-
args = (ax, x, y)
1205-
return self.plt.Axes.plot(*args, **kwds)
1206-
return plotf
1322+
return _mplplot_plotf(self._plot_errors())
12071323

12081324
def _get_index_name(self):
12091325
if isinstance(self.data.index, MultiIndex):
@@ -1590,7 +1706,6 @@ def _is_ts_plot(self):
15901706
return not self.x_compat and self.use_index and self._use_dynamic_x()
15911707

15921708
def _make_plot(self):
1593-
self._initialize_prior(len(self.data))
15941709

15951710
if self._is_ts_plot():
15961711
data = self._maybe_convert_index(self.data)
@@ -1622,12 +1737,13 @@ def _make_plot(self):
16221737
left, right = _get_xlim(lines)
16231738
ax.set_xlim(left, right)
16241739

1625-
def _get_stacked_values(self, y, label):
1626-
if self.stacked:
1740+
@classmethod
1741+
def _get_stacked_values(cls, ax, y, label, stacked):
1742+
if stacked:
16271743
if (y >= 0).all():
1628-
return self._pos_prior + y
1744+
return ax._pos_prior + y
16291745
elif (y <= 0).all():
1630-
return self._neg_prior + y
1746+
return ax._neg_prior + y
16311747
else:
16321748
raise ValueError('When stacked is True, each column must be either all positive or negative.'
16331749
'{0} contains both positive and negative values'.format(label))
@@ -1636,15 +1752,8 @@ def _get_stacked_values(self, y, label):
16361752

16371753
def _get_plot_function(self):
16381754
f = MPLPlot._get_plot_function(self)
1639-
def plotf(ax, x, y, style=None, column_num=None, **kwds):
1640-
# column_num is used to get the target column from protf in line and area plots
1641-
if column_num == 0:
1642-
self._initialize_prior(len(self.data))
1643-
y_values = self._get_stacked_values(y, kwds['label'])
1644-
lines = f(ax, x, y_values, style=style, **kwds)
1645-
self._update_prior(y)
1646-
return lines
1647-
return plotf
1755+
1756+
return _lineplot_plotf(f, self.stacked, self.subplots)
16481757

16491758
def _get_ts_plot_function(self):
16501759
from pandas.tseries.plotting import tsplot
@@ -1656,19 +1765,21 @@ def _plot(ax, x, data, style=None, **kwds):
16561765
return lines
16571766
return _plot
16581767

1659-
def _initialize_prior(self, n):
1660-
self._pos_prior = np.zeros(n)
1661-
self._neg_prior = np.zeros(n)
1768+
@classmethod
1769+
def _initialize_prior(cls, ax, n):
1770+
ax._pos_prior = np.zeros(n)
1771+
ax._neg_prior = np.zeros(n)
16621772

1663-
def _update_prior(self, y):
1664-
if self.stacked and not self.subplots:
1773+
@classmethod
1774+
def _update_prior(cls, ax, y, stacked, subplots):
1775+
if stacked and not subplots:
16651776
# tsplot resample may changedata length
1666-
if len(self._pos_prior) != len(y):
1667-
self._initialize_prior(len(y))
1777+
if len(ax._pos_prior) != len(y):
1778+
cls._initialize_prior(ax, len(y))
16681779
if (y >= 0).all():
1669-
self._pos_prior += y
1780+
ax._pos_prior += y
16701781
elif (y <= 0).all():
1671-
self._neg_prior += y
1782+
ax._neg_prior += y
16721783

16731784
def _maybe_convert_index(self, data):
16741785
# tsplot converts automatically, but don't want to convert index
@@ -1735,28 +1846,8 @@ def _get_plot_function(self):
17351846
raise ValueError("Log-y scales are not supported in area plot")
17361847
else:
17371848
f = MPLPlot._get_plot_function(self)
1738-
def plotf(ax, x, y, style=None, column_num=None, **kwds):
1739-
if column_num == 0:
1740-
self._initialize_prior(len(self.data))
1741-
y_values = self._get_stacked_values(y, kwds['label'])
1742-
lines = f(ax, x, y_values, style=style, **kwds)
1743-
1744-
# get data from the line to get coordinates for fill_between
1745-
xdata, y_values = lines[0].get_data(orig=False)
1746-
1747-
if (y >= 0).all():
1748-
start = self._pos_prior
1749-
elif (y <= 0).all():
1750-
start = self._neg_prior
1751-
else:
1752-
start = np.zeros(len(y))
1753-
1754-
if not 'color' in kwds:
1755-
kwds['color'] = lines[0].get_color()
17561849

1757-
self.plt.Axes.fill_between(ax, xdata, start, y_values, **kwds)
1758-
self._update_prior(y)
1759-
return lines
1850+
return _areaplot_plotf(f, self.stacked, self.subplots)
17601851

17611852
return plotf
17621853

@@ -1946,17 +2037,7 @@ def _args_adjust(self):
19462037
self.bottom = np.array(self.bottom)
19472038

19482039
def _get_plot_function(self):
1949-
def plotf(ax, y, style=None, column_num=None, **kwds):
1950-
if column_num == 0:
1951-
self._initialize_prior(len(self.bins) - 1)
1952-
y = y[~com.isnull(y)]
1953-
bottom = self._pos_prior + self.bottom
1954-
# ignore style
1955-
n, bins, patches = self.plt.Axes.hist(ax, y, bins=self.bins,
1956-
bottom=bottom, **kwds)
1957-
self._update_prior(n)
1958-
return patches
1959-
return plotf
2040+
return _histplot_plotf(self.bins, self.bottom, self.stacked, self.subplots)
19602041

19612042
def _make_plot(self):
19622043
plotf = self._get_plot_function()
@@ -2003,35 +2084,9 @@ def __init__(self, data, bw_method=None, ind=None, **kwargs):
20032084
def _args_adjust(self):
20042085
pass
20052086

2006-
def _get_ind(self, y):
2007-
if self.ind is None:
2008-
sample_range = max(y) - min(y)
2009-
ind = np.linspace(min(y) - 0.5 * sample_range,
2010-
max(y) + 0.5 * sample_range, 1000)
2011-
else:
2012-
ind = self.ind
2013-
return ind
2014-
20152087
def _get_plot_function(self):
2016-
from scipy.stats import gaussian_kde
2017-
from scipy import __version__ as spv
20182088
f = MPLPlot._get_plot_function(self)
2019-
def plotf(ax, y, style=None, column_num=None, **kwds):
2020-
y = remove_na(y)
2021-
if LooseVersion(spv) >= '0.11.0':
2022-
gkde = gaussian_kde(y, bw_method=self.bw_method)
2023-
else:
2024-
gkde = gaussian_kde(y)
2025-
if self.bw_method is not None:
2026-
msg = ('bw_method was added in Scipy 0.11.0.' +
2027-
' Scipy version in use is %s.' % spv)
2028-
warnings.warn(msg)
2029-
2030-
ind = self._get_ind(y)
2031-
y = gkde.evaluate(ind)
2032-
lines = f(ax, ind, y, style=style, **kwds)
2033-
return lines
2034-
return plotf
2089+
return _kdeplot_plotf(f, self.bw_method, self.ind)
20352090

20362091
def _post_plot_logic(self):
20372092
for ax in self.axes:
@@ -2126,24 +2181,7 @@ def _args_adjust(self):
21262181
self.sharey = False
21272182

21282183
def _get_plot_function(self):
2129-
def plotf(ax, y, column_num=None, **kwds):
2130-
if y.ndim == 2:
2131-
y = [remove_na(v) for v in y]
2132-
# Boxplot fails with empty arrays, so need to add a NaN
2133-
# if any cols are empty
2134-
# GH 8181
2135-
y = [v if v.size > 0 else np.array([np.nan]) for v in y]
2136-
else:
2137-
y = remove_na(y)
2138-
bp = ax.boxplot(y, **kwds)
2139-
2140-
if self.return_type == 'dict':
2141-
return bp, bp
2142-
elif self.return_type == 'both':
2143-
return self.BP(ax=ax, lines=bp), bp
2144-
else:
2145-
return ax, bp
2146-
return plotf
2184+
return _boxplot_plotf(self.return_type)
21472185

21482186
def _validate_color_args(self):
21492187
if 'color' in self.kwds:

0 commit comments

Comments
 (0)