From 2c492fa00b736aaf872c5d9bf5206df33954fbbf Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 30 May 2017 09:28:49 -0500 Subject: [PATCH] TST: Avoid global state in matplotlib tests Replaces most uses of implicit global state from matplotlib in test_datetimelike.py. This was potentially causing random failures where a figure expected to be on a new, blank figure would instead plot on an existing axes (that's the guess at least). --- pandas/tests/plotting/test_datetimelike.py | 379 +++++++++++---------- pandas/tests/plotting/test_series.py | 159 ++++++--- 2 files changed, 301 insertions(+), 237 deletions(-) diff --git a/pandas/tests/plotting/test_datetimelike.py b/pandas/tests/plotting/test_datetimelike.py index 0e15aaa2555f4..0cff365be3ec8 100644 --- a/pandas/tests/plotting/test_datetimelike.py +++ b/pandas/tests/plotting/test_datetimelike.py @@ -55,16 +55,15 @@ def test_ts_plot_with_tz(self): def test_fontsize_set_correctly(self): # For issue #8765 - import matplotlib.pyplot as plt # noqa df = DataFrame(np.random.randn(10, 9), index=range(10)) - ax = df.plot(fontsize=2) + fig, ax = self.plt.subplots() + df.plot(fontsize=2, ax=ax) for label in (ax.get_xticklabels() + ax.get_yticklabels()): assert label.get_fontsize() == 2 @slow def test_frame_inferred(self): # inferred freq - import matplotlib.pyplot as plt # noqa idx = date_range('1/1/1987', freq='MS', periods=100) idx = DatetimeIndex(idx.values, freq=None) @@ -90,26 +89,24 @@ def test_is_error_nozeroindex(self): _check_plot_works(a.plot, yerr=a) def test_nonnumeric_exclude(self): - import matplotlib.pyplot as plt - idx = date_range('1/1/1987', freq='A', periods=3) df = DataFrame({'A': ["x", "y", "z"], 'B': [1, 2, 3]}, idx) - ax = df.plot() # it works + fig, ax = self.plt.subplots() + df.plot(ax=ax) # it works assert len(ax.get_lines()) == 1 # B was plotted - plt.close(plt.gcf()) + self.plt.close(fig) pytest.raises(TypeError, df['A'].plot) @slow def test_tsplot(self): from pandas.tseries.plotting import tsplot - import matplotlib.pyplot as plt - ax = plt.gca() + _, ax = self.plt.subplots() ts = tm.makeTimeSeries() - f = lambda *args, **kwds: tsplot(s, plt.Axes.plot, *args, **kwds) + f = lambda *args, **kwds: tsplot(s, self.plt.Axes.plot, *args, **kwds) for s in self.period_ser: _check_plot_works(f, s.index.freq, ax=ax, series=s) @@ -123,12 +120,12 @@ def test_tsplot(self): for s in self.datetime_ser: _check_plot_works(s.plot, ax=ax) - ax = ts.plot(style='k') + _, ax = self.plt.subplots() + ts.plot(style='k', ax=ax) color = (0., 0., 0., 1) if self.mpl_ge_2_0_0 else (0., 0., 0.) assert color == ax.get_lines()[0].get_color() def test_both_style_and_color(self): - import matplotlib.pyplot as plt # noqa ts = tm.makeTimeSeries() pytest.raises(ValueError, ts.plot, style='b-', color='#000099') @@ -140,9 +137,10 @@ def test_both_style_and_color(self): def test_high_freq(self): freaks = ['ms', 'us'] for freq in freaks: + _, ax = self.plt.subplots() rng = date_range('1/1/2012', periods=100000, freq=freq) ser = Series(np.random.randn(len(rng)), rng) - _check_plot_works(ser.plot) + _check_plot_works(ser.plot, ax=ax) def test_get_datevalue(self): from pandas.plotting._converter import get_datevalue @@ -167,22 +165,25 @@ def check_format_of_first_point(ax, expected_string): annual = Series(1, index=date_range('2014-01-01', periods=3, freq='A-DEC')) - check_format_of_first_point(annual.plot(), 't = 2014 y = 1.000000') + _, ax = self.plt.subplots() + annual.plot(ax=ax) + check_format_of_first_point(ax, 't = 2014 y = 1.000000') # note this is added to the annual plot already in existence, and # changes its freq field daily = Series(1, index=date_range('2014-01-01', periods=3, freq='D')) - check_format_of_first_point(daily.plot(), + daily.plot(ax=ax) + check_format_of_first_point(ax, 't = 2014-01-01 y = 1.000000') tm.close() # tsplot - import matplotlib.pyplot as plt + _, ax = self.plt.subplots() from pandas.tseries.plotting import tsplot - tsplot(annual, plt.Axes.plot) - check_format_of_first_point(plt.gca(), 't = 2014 y = 1.000000') - tsplot(daily, plt.Axes.plot) - check_format_of_first_point(plt.gca(), 't = 2014-01-01 y = 1.000000') + tsplot(annual, self.plt.Axes.plot, ax=ax) + check_format_of_first_point(ax, 't = 2014 y = 1.000000') + tsplot(daily, self.plt.Axes.plot, ax=ax) + check_format_of_first_point(ax, 't = 2014-01-01 y = 1.000000') @slow def test_line_plot_period_series(self): @@ -215,14 +216,11 @@ def test_line_plot_inferred_freq(self): _check_plot_works(ser.plot) def test_fake_inferred_business(self): - import matplotlib.pyplot as plt - fig = plt.gcf() - plt.clf() - fig.add_subplot(111) + _, ax = self.plt.subplots() rng = date_range('2001-1-1', '2001-1-10') ts = Series(lrange(len(rng)), rng) ts = ts[:3].append(ts[5:]) - ax = ts.plot() + ts.plot(ax=ax) assert not hasattr(ax, 'freq') @slow @@ -244,15 +242,11 @@ def test_plot_multiple_inferred_freq(self): @slow def test_uhf(self): import pandas.plotting._converter as conv - import matplotlib.pyplot as plt - fig = plt.gcf() - plt.clf() - fig.add_subplot(111) - idx = date_range('2012-6-22 21:59:51.960928', freq='L', periods=500) df = DataFrame(np.random.randn(len(idx), 2), idx) - ax = df.plot() + _, ax = self.plt.subplots() + df.plot(ax=ax) axis = ax.get_xaxis() tlocs = axis.get_ticklocs() @@ -265,49 +259,40 @@ def test_uhf(self): @slow def test_irreg_hf(self): - import matplotlib.pyplot as plt - fig = plt.gcf() - plt.clf() - fig.add_subplot(111) - idx = date_range('2012-6-22 21:59:51', freq='S', periods=100) df = DataFrame(np.random.randn(len(idx), 2), idx) irreg = df.iloc[[0, 1, 3, 4]] - ax = irreg.plot() + _, ax = self.plt.subplots() + irreg.plot(ax=ax) diffs = Series(ax.get_lines()[0].get_xydata()[:, 0]).diff() sec = 1. / 24 / 60 / 60 assert (np.fabs(diffs[1:] - [sec, sec * 2, sec]) < 1e-8).all() - plt.clf() - fig.add_subplot(111) + _, ax = self.plt.subplots() df2 = df.copy() df2.index = df.index.asobject - ax = df2.plot() + df2.plot(ax=ax) diffs = Series(ax.get_lines()[0].get_xydata()[:, 0]).diff() assert (np.fabs(diffs[1:] - sec) < 1e-8).all() def test_irregular_datetime64_repr_bug(self): - import matplotlib.pyplot as plt ser = tm.makeTimeSeries() ser = ser[[0, 1, 2, 7]] - fig = plt.gcf() - plt.clf() + _, ax = self.plt.subplots() - ax = fig.add_subplot(211) - - ret = ser.plot() + ret = ser.plot(ax=ax) assert ret is not None for rs, xp in zip(ax.get_lines()[0].get_xdata(), ser.index): assert rs == xp def test_business_freq(self): - import matplotlib.pyplot as plt # noqa bts = tm.makePeriodSeries() - ax = bts.plot() + _, ax = self.plt.subplots() + bts.plot(ax=ax) assert ax.get_lines()[0].get_xydata()[0, 0] == bts.index[0].ordinal idx = ax.get_lines()[0].get_xdata() assert PeriodIndex(data=idx).freqstr == 'B' @@ -319,7 +304,8 @@ def test_business_freq_convert(self): bts = tm.makeTimeSeries().asfreq('BM') tm.N = n ts = bts.to_period('M') - ax = bts.plot() + _, ax = self.plt.subplots() + bts.plot(ax=ax) assert ax.get_lines()[0].get_xydata()[0, 0] == ts.index[0].ordinal idx = ax.get_lines()[0].get_xdata() assert PeriodIndex(data=idx).freqstr == 'M' @@ -329,19 +315,20 @@ def test_nonzero_base(self): idx = (date_range('2012-12-20', periods=24, freq='H') + timedelta( minutes=30)) df = DataFrame(np.arange(24), index=idx) - ax = df.plot() + _, ax = self.plt.subplots() + df.plot(ax=ax) rs = ax.get_lines()[0].get_xdata() assert not Index(rs).is_normalized def test_dataframe(self): bts = DataFrame({'a': tm.makeTimeSeries()}) - ax = bts.plot() + _, ax = self.plt.subplots() + bts.plot(ax=ax) idx = ax.get_lines()[0].get_xdata() tm.assert_index_equal(bts.index.to_period(), PeriodIndex(idx)) @slow def test_axis_limits(self): - import matplotlib.pyplot as plt def _test(ax): xlim = ax.get_xlim() @@ -369,14 +356,16 @@ def _test(ax): assert int(result[0]) == expected[0].ordinal assert int(result[1]) == expected[1].ordinal fig = ax.get_figure() - plt.close(fig) + self.plt.close(fig) ser = tm.makeTimeSeries() - ax = ser.plot() + _, ax = self.plt.subplots() + ser.plot(ax=ax) _test(ax) + _, ax = self.plt.subplots() df = DataFrame({'a': ser, 'b': ser + 1}) - ax = df.plot() + df.plot(ax=ax) _test(ax) df = DataFrame({'a': ser, 'b': ser + 1}) @@ -397,13 +386,13 @@ def test_get_finder(self): @slow def test_finder_daily(self): - import matplotlib.pyplot as plt xp = Period('1999-1-1', freq='B').ordinal day_lst = [10, 40, 252, 400, 950, 2750, 10000] for n in day_lst: rng = bdate_range('1999-1-1', periods=n) ser = Series(np.random.randn(len(rng)), rng) - ax = ser.plot() + _, ax = self.plt.subplots() + ser.plot(ax=ax) xaxis = ax.get_xaxis() rs = xaxis.get_majorticklocs()[0] assert xp == rs @@ -411,17 +400,17 @@ def test_finder_daily(self): ax.set_xlim(vmin + 0.9, vmax) rs = xaxis.get_majorticklocs()[0] assert xp == rs - plt.close(ax.get_figure()) + self.plt.close(ax.get_figure()) @slow def test_finder_quarterly(self): - import matplotlib.pyplot as plt xp = Period('1988Q1').ordinal yrs = [3.5, 11] for n in yrs: rng = period_range('1987Q2', periods=int(n * 4), freq='Q') ser = Series(np.random.randn(len(rng)), rng) - ax = ser.plot() + _, ax = self.plt.subplots() + ser.plot(ax=ax) xaxis = ax.get_xaxis() rs = xaxis.get_majorticklocs()[0] assert rs == xp @@ -429,17 +418,17 @@ def test_finder_quarterly(self): ax.set_xlim(vmin + 0.9, vmax) rs = xaxis.get_majorticklocs()[0] assert xp == rs - plt.close(ax.get_figure()) + self.plt.close(ax.get_figure()) @slow def test_finder_monthly(self): - import matplotlib.pyplot as plt xp = Period('Jan 1988').ordinal yrs = [1.15, 2.5, 4, 11] for n in yrs: rng = period_range('1987Q2', periods=int(n * 12), freq='M') ser = Series(np.random.randn(len(rng)), rng) - ax = ser.plot() + _, ax = self.plt.subplots() + ser.plot(ax=ax) xaxis = ax.get_xaxis() rs = xaxis.get_majorticklocs()[0] assert rs == xp @@ -447,12 +436,13 @@ def test_finder_monthly(self): ax.set_xlim(vmin + 0.9, vmax) rs = xaxis.get_majorticklocs()[0] assert xp == rs - plt.close(ax.get_figure()) + self.plt.close(ax.get_figure()) def test_finder_monthly_long(self): rng = period_range('1988Q1', periods=24 * 12, freq='M') ser = Series(np.random.randn(len(rng)), rng) - ax = ser.plot() + _, ax = self.plt.subplots() + ser.plot(ax=ax) xaxis = ax.get_xaxis() rs = xaxis.get_majorticklocs()[0] xp = Period('1989Q1', 'M').ordinal @@ -460,23 +450,24 @@ def test_finder_monthly_long(self): @slow def test_finder_annual(self): - import matplotlib.pyplot as plt xp = [1987, 1988, 1990, 1990, 1995, 2020, 2070, 2170] for i, nyears in enumerate([5, 10, 19, 49, 99, 199, 599, 1001]): rng = period_range('1987', periods=nyears, freq='A') ser = Series(np.random.randn(len(rng)), rng) - ax = ser.plot() + _, ax = self.plt.subplots() + ser.plot(ax=ax) xaxis = ax.get_xaxis() rs = xaxis.get_majorticklocs()[0] assert rs == Period(xp[i], freq='A').ordinal - plt.close(ax.get_figure()) + self.plt.close(ax.get_figure()) @slow def test_finder_minutely(self): nminutes = 50 * 24 * 60 rng = date_range('1/1/1999', freq='Min', periods=nminutes) ser = Series(np.random.randn(len(rng)), rng) - ax = ser.plot() + _, ax = self.plt.subplots() + ser.plot(ax=ax) xaxis = ax.get_xaxis() rs = xaxis.get_majorticklocs()[0] xp = Period('1/1/1999', freq='Min').ordinal @@ -486,7 +477,8 @@ def test_finder_hourly(self): nhours = 23 rng = date_range('1/1/1999', freq='H', periods=nhours) ser = Series(np.random.randn(len(rng)), rng) - ax = ser.plot() + _, ax = self.plt.subplots() + ser.plot(ax=ax) xaxis = ax.get_xaxis() rs = xaxis.get_majorticklocs()[0] xp = Period('1/1/1999', freq='H').ordinal @@ -494,11 +486,10 @@ def test_finder_hourly(self): @slow def test_gaps(self): - import matplotlib.pyplot as plt - ts = tm.makeTimeSeries() ts[5:25] = np.nan - ax = ts.plot() + _, ax = self.plt.subplots() + ts.plot(ax=ax) lines = ax.get_lines() tm._skip_if_mpl_1_5() assert len(lines) == 1 @@ -507,13 +498,14 @@ def test_gaps(self): assert isinstance(data, np.ma.core.MaskedArray) mask = data.mask assert mask[5:25, 1].all() - plt.close(ax.get_figure()) + self.plt.close(ax.get_figure()) # irregular ts = tm.makeTimeSeries() ts = ts[[0, 1, 2, 5, 7, 9, 12, 15, 20]] ts[2:5] = np.nan - ax = ts.plot() + _, ax = self.plt.subplots() + ax = ts.plot(ax=ax) lines = ax.get_lines() assert len(lines) == 1 l = lines[0] @@ -521,13 +513,14 @@ def test_gaps(self): assert isinstance(data, np.ma.core.MaskedArray) mask = data.mask assert mask[2:5, 1].all() - plt.close(ax.get_figure()) + self.plt.close(ax.get_figure()) # non-ts idx = [0, 1, 2, 5, 7, 9, 12, 15, 20] ser = Series(np.random.randn(len(idx)), idx) ser[2:5] = np.nan - ax = ser.plot() + _, ax = self.plt.subplots() + ser.plot(ax=ax) lines = ax.get_lines() assert len(lines) == 1 l = lines[0] @@ -540,7 +533,8 @@ def test_gaps(self): def test_gap_upsample(self): low = tm.makeTimeSeries() low[5:25] = np.nan - ax = low.plot() + _, ax = self.plt.subplots() + low.plot(ax=ax) idxh = date_range(low.index[0], low.index[-1], freq='12h') s = Series(np.random.randn(len(idxh)), idxh) @@ -559,26 +553,25 @@ def test_gap_upsample(self): @slow def test_secondary_y(self): - import matplotlib.pyplot as plt - ser = Series(np.random.randn(10)) ser2 = Series(np.random.randn(10)) + fig, _ = self.plt.subplots() ax = ser.plot(secondary_y=True) assert hasattr(ax, 'left_ax') assert not hasattr(ax, 'right_ax') - fig = ax.get_figure() axes = fig.get_axes() l = ax.get_lines()[0] xp = Series(l.get_ydata(), l.get_xdata()) assert_series_equal(ser, xp) assert ax.get_yaxis().get_ticks_position() == 'right' assert not axes[0].get_yaxis().get_visible() - plt.close(fig) + self.plt.close(fig) - ax2 = ser2.plot() + _, ax2 = self.plt.subplots() + ser2.plot(ax=ax2) assert (ax2.get_yaxis().get_ticks_position() == self.default_tick_position) - plt.close(ax2.get_figure()) + self.plt.close(ax2.get_figure()) ax = ser2.plot() ax2 = ser.plot(secondary_y=True) @@ -590,26 +583,26 @@ def test_secondary_y(self): @slow def test_secondary_y_ts(self): - import matplotlib.pyplot as plt idx = date_range('1/1/2000', periods=10) ser = Series(np.random.randn(10), idx) ser2 = Series(np.random.randn(10), idx) + fig, _ = self.plt.subplots() ax = ser.plot(secondary_y=True) assert hasattr(ax, 'left_ax') assert not hasattr(ax, 'right_ax') - fig = ax.get_figure() axes = fig.get_axes() l = ax.get_lines()[0] xp = Series(l.get_ydata(), l.get_xdata()).to_timestamp() assert_series_equal(ser, xp) assert ax.get_yaxis().get_ticks_position() == 'right' assert not axes[0].get_yaxis().get_visible() - plt.close(fig) + self.plt.close(fig) - ax2 = ser2.plot() + _, ax2 = self.plt.subplots() + ser2.plot(ax=ax2) assert (ax2.get_yaxis().get_ticks_position() == self.default_tick_position) - plt.close(ax2.get_figure()) + self.plt.close(ax2.get_figure()) ax = ser2.plot() ax2 = ser.plot(secondary_y=True) @@ -620,20 +613,19 @@ def test_secondary_kde(self): tm._skip_if_no_scipy() _skip_if_no_scipy_gaussian_kde() - import matplotlib.pyplot as plt # noqa ser = Series(np.random.randn(10)) - ax = ser.plot(secondary_y=True, kind='density') + fig, ax = self.plt.subplots() + ax = ser.plot(secondary_y=True, kind='density', ax=ax) assert hasattr(ax, 'left_ax') assert not hasattr(ax, 'right_ax') - fig = ax.get_figure() axes = fig.get_axes() assert axes[1].get_yaxis().get_ticks_position() == 'right' @slow def test_secondary_bar(self): ser = Series(np.random.randn(10)) - ax = ser.plot(secondary_y=True, kind='bar') - fig = ax.get_figure() + fig, ax = self.plt.subplots() + ser.plot(secondary_y=True, kind='bar', ax=ax) axes = fig.get_axes() assert axes[1].get_yaxis().get_ticks_position() == 'right' @@ -656,7 +648,7 @@ def test_secondary_bar_frame(self): assert axes[2].get_yaxis().get_ticks_position() == 'right' def test_mixed_freq_regular_first(self): - import matplotlib.pyplot as plt # noqa + # TODO s1 = tm.makeTimeSeries() s2 = s1[[0, 5, 10, 11, 12, 13, 14, 15]] @@ -676,11 +668,11 @@ def test_mixed_freq_regular_first(self): @slow def test_mixed_freq_irregular_first(self): - import matplotlib.pyplot as plt # noqa s1 = tm.makeTimeSeries() s2 = s1[[0, 5, 10, 11, 12, 13, 14, 15]] - s2.plot(style='g') - ax = s1.plot() + _, ax = self.plt.subplots() + s2.plot(style='g', ax=ax) + s1.plot(ax=ax) assert not hasattr(ax, 'freq') lines = ax.get_lines() x1 = lines[0].get_xdata() @@ -690,10 +682,10 @@ def test_mixed_freq_irregular_first(self): def test_mixed_freq_regular_first_df(self): # GH 9852 - import matplotlib.pyplot as plt # noqa s1 = tm.makeTimeSeries().to_frame() s2 = s1.iloc[[0, 5, 10, 11, 12, 13, 14, 15], :] - ax = s1.plot() + _, ax = self.plt.subplots() + s1.plot(ax=ax) ax2 = s2.plot(style='g', ax=ax) lines = ax2.get_lines() idx1 = PeriodIndex(lines[0].get_xdata()) @@ -708,11 +700,11 @@ def test_mixed_freq_regular_first_df(self): @slow def test_mixed_freq_irregular_first_df(self): # GH 9852 - import matplotlib.pyplot as plt # noqa s1 = tm.makeTimeSeries().to_frame() s2 = s1.iloc[[0, 5, 10, 11, 12, 13, 14, 15], :] - ax = s2.plot(style='g') - ax = s1.plot(ax=ax) + _, ax = self.plt.subplots() + s2.plot(style='g', ax=ax) + s1.plot(ax=ax) assert not hasattr(ax, 'freq') lines = ax.get_lines() x1 = lines[0].get_xdata() @@ -725,8 +717,9 @@ def test_mixed_freq_hf_first(self): idxl = date_range('1/1/1999', periods=12, freq='M') high = Series(np.random.randn(len(idxh)), idxh) low = Series(np.random.randn(len(idxl)), idxl) - high.plot() - ax = low.plot() + _, ax = self.plt.subplots() + high.plot(ax=ax) + low.plot(ax=ax) for l in ax.get_lines(): assert PeriodIndex(data=l.get_xdata()).freq == 'D' @@ -738,33 +731,35 @@ def test_mixed_freq_alignment(self): ts = Series(ts_data, index=ts_ind) ts2 = ts.asfreq('T').interpolate() - ax = ts.plot() - ts2.plot(style='r') + _, ax = self.plt.subplots() + ax = ts.plot(ax=ax) + ts2.plot(style='r', ax=ax) assert ax.lines[0].get_xdata()[0] == ax.lines[1].get_xdata()[0] @slow def test_mixed_freq_lf_first(self): - import matplotlib.pyplot as plt idxh = date_range('1/1/1999', periods=365, freq='D') idxl = date_range('1/1/1999', periods=12, freq='M') high = Series(np.random.randn(len(idxh)), idxh) low = Series(np.random.randn(len(idxl)), idxl) - low.plot(legend=True) - ax = high.plot(legend=True) + _, ax = self.plt.subplots() + low.plot(legend=True, ax=ax) + high.plot(legend=True, ax=ax) for l in ax.get_lines(): assert PeriodIndex(data=l.get_xdata()).freq == 'D' leg = ax.get_legend() assert len(leg.texts) == 2 - plt.close(ax.get_figure()) + self.plt.close(ax.get_figure()) idxh = date_range('1/1/1999', periods=240, freq='T') idxl = date_range('1/1/1999', periods=4, freq='H') high = Series(np.random.randn(len(idxh)), idxh) low = Series(np.random.randn(len(idxl)), idxl) - low.plot() - ax = high.plot() + _, ax = self.plt.subplots() + low.plot(ax=ax) + high.plot(ax=ax) for l in ax.get_lines(): assert PeriodIndex(data=l.get_xdata()).freq == 'T' @@ -773,8 +768,9 @@ def test_mixed_freq_irreg_period(self): irreg = ts[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 16, 17, 18, 29]] rng = period_range('1/3/2000', periods=30, freq='B') ps = Series(np.random.randn(len(rng)), rng) - irreg.plot() - ps.plot() + _, ax = self.plt.subplots() + irreg.plot(ax=ax) + ps.plot(ax=ax) def test_mixed_freq_shared_ax(self): @@ -813,9 +809,7 @@ def test_mixed_freq_shared_ax(self): def test_nat_handling(self): - fig = self.plt.gcf() - # self.plt.clf() - ax = fig.add_subplot(111) + _, ax = self.plt.subplots() dti = DatetimeIndex(['2015-01-01', NaT, '2015-01-03']) s = Series(range(len(dti)), dti) @@ -831,17 +825,18 @@ def test_to_weekly_resampling(self): idxl = date_range('1/1/1999', periods=12, freq='M') high = Series(np.random.randn(len(idxh)), idxh) low = Series(np.random.randn(len(idxl)), idxl) - high.plot() - ax = low.plot() + _, ax = self.plt.subplots() + high.plot(ax=ax) + low.plot(ax=ax) for l in ax.get_lines(): assert PeriodIndex(data=l.get_xdata()).freq == idxh.freq # tsplot from pandas.tseries.plotting import tsplot - import matplotlib.pyplot as plt - tsplot(high, plt.Axes.plot) - lines = tsplot(low, plt.Axes.plot) + _, ax = self.plt.subplots() + tsplot(high, self.plt.Axes.plot, ax=ax) + lines = tsplot(low, self.plt.Axes.plot, ax=ax) for l in lines: assert PeriodIndex(data=l.get_xdata()).freq == idxh.freq @@ -851,8 +846,9 @@ def test_from_weekly_resampling(self): idxl = date_range('1/1/1999', periods=12, freq='M') high = Series(np.random.randn(len(idxh)), idxh) low = Series(np.random.randn(len(idxl)), idxl) - low.plot() - ax = high.plot() + _, ax = self.plt.subplots() + low.plot(ax=ax) + high.plot(ax=ax) expected_h = idxh.to_period().asi8.astype(np.float64) expected_l = np.array([1514, 1519, 1523, 1527, 1531, 1536, 1540, 1544, @@ -868,10 +864,10 @@ def test_from_weekly_resampling(self): # tsplot from pandas.tseries.plotting import tsplot - import matplotlib.pyplot as plt - tsplot(low, plt.Axes.plot) - lines = tsplot(high, plt.Axes.plot) + _, ax = self.plt.subplots() + tsplot(low, self.plt.Axes.plot, ax=ax) + lines = tsplot(high, self.plt.Axes.plot, ax=ax) for l in lines: assert PeriodIndex(data=l.get_xdata()).freq == idxh.freq xdata = l.get_xdata(orig=False) @@ -891,8 +887,9 @@ def test_from_resampling_area_line_mixed(self): # low to high for kind1, kind2 in [('line', 'area'), ('area', 'line')]: - ax = low.plot(kind=kind1, stacked=True) - ax = high.plot(kind=kind2, stacked=True, ax=ax) + _, ax = self.plt.subplots() + low.plot(kind=kind1, stacked=True, ax=ax) + high.plot(kind=kind2, stacked=True, ax=ax) # check low dataframe result expected_x = np.array([1514, 1519, 1523, 1527, 1531, 1536, 1540, @@ -923,8 +920,9 @@ def test_from_resampling_area_line_mixed(self): # high to low for kind1, kind2 in [('line', 'area'), ('area', 'line')]: - ax = high.plot(kind=kind1, stacked=True) - ax = low.plot(kind=kind2, stacked=True, ax=ax) + _, ax = self.plt.subplots() + high.plot(kind=kind1, stacked=True, ax=ax) + low.plot(kind=kind2, stacked=True, ax=ax) # check high dataframe result expected_x = idxh.to_period().asi8.astype(np.float64) @@ -960,16 +958,18 @@ def test_mixed_freq_second_millisecond(self): high = Series(np.random.randn(len(idxh)), idxh) low = Series(np.random.randn(len(idxl)), idxl) # high to low - high.plot() - ax = low.plot() + _, ax = self.plt.subplots() + high.plot(ax=ax) + low.plot(ax=ax) assert len(ax.get_lines()) == 2 for l in ax.get_lines(): assert PeriodIndex(data=l.get_xdata()).freq == 'L' tm.close() # low to high - low.plot() - ax = high.plot() + _, ax = self.plt.subplots() + low.plot(ax=ax) + high.plot(ax=ax) assert len(ax.get_lines()) == 2 for l in ax.get_lines(): assert PeriodIndex(data=l.get_xdata()).freq == 'L' @@ -985,7 +985,8 @@ def test_irreg_dtypes(self): idx = date_range('1/1/2000', periods=10) idx = idx[[0, 2, 5, 9]].asobject df = DataFrame(np.random.randn(len(idx), 3), idx) - _check_plot_works(df.plot) + _, ax = self.plt.subplots() + _check_plot_works(df.plot, ax=ax) @slow def test_time(self): @@ -995,7 +996,8 @@ def test_time(self): df = DataFrame({'a': np.random.randn(len(ts)), 'b': np.random.randn(len(ts))}, index=ts) - ax = df.plot() + _, ax = self.plt.subplots() + df.plot(ax=ax) # verify tick labels ticks = ax.get_xticks() @@ -1031,7 +1033,8 @@ def test_time_musec(self): df = DataFrame({'a': np.random.randn(len(ts)), 'b': np.random.randn(len(ts))}, index=ts) - ax = df.plot() + _, ax = self.plt.subplots() + ax = df.plot(ax=ax) # verify tick labels ticks = ax.get_xticks() @@ -1054,8 +1057,9 @@ def test_secondary_upsample(self): idxl = date_range('1/1/1999', periods=12, freq='M') high = Series(np.random.randn(len(idxh)), idxh) low = Series(np.random.randn(len(idxl)), idxl) - low.plot() - ax = high.plot(secondary_y=True) + _, ax = self.plt.subplots() + low.plot(ax=ax) + ax = high.plot(secondary_y=True, ax=ax) for l in ax.get_lines(): assert PeriodIndex(l.get_xdata()).freq == 'D' assert hasattr(ax, 'left_ax') @@ -1065,14 +1069,12 @@ def test_secondary_upsample(self): @slow def test_secondary_legend(self): - import matplotlib.pyplot as plt - fig = plt.gcf() - plt.clf() + fig = self.plt.figure() ax = fig.add_subplot(211) # ts df = tm.makeTimeDataFrame() - ax = df.plot(secondary_y=['A', 'B']) + df.plot(secondary_y=['A', 'B'], ax=ax) leg = ax.get_legend() assert len(leg.get_lines()) == 4 assert leg.get_texts()[0].get_text() == 'A (right)' @@ -1086,33 +1088,37 @@ def test_secondary_legend(self): # TODO: color cycle problems assert len(colors) == 4 + self.plt.close(fig) - plt.clf() + fig = self.plt.figure() ax = fig.add_subplot(211) - ax = df.plot(secondary_y=['A', 'C'], mark_right=False) + df.plot(secondary_y=['A', 'C'], mark_right=False, ax=ax) leg = ax.get_legend() assert len(leg.get_lines()) == 4 assert leg.get_texts()[0].get_text() == 'A' assert leg.get_texts()[1].get_text() == 'B' assert leg.get_texts()[2].get_text() == 'C' assert leg.get_texts()[3].get_text() == 'D' + self.plt.close(fig) - plt.clf() - ax = df.plot(kind='bar', secondary_y=['A']) + fig, ax = self.plt.subplots() + df.plot(kind='bar', secondary_y=['A'], ax=ax) leg = ax.get_legend() assert leg.get_texts()[0].get_text() == 'A (right)' assert leg.get_texts()[1].get_text() == 'B' + self.plt.close(fig) - plt.clf() - ax = df.plot(kind='bar', secondary_y=['A'], mark_right=False) + fig, ax = self.plt.subplots() + df.plot(kind='bar', secondary_y=['A'], mark_right=False, ax=ax) leg = ax.get_legend() assert leg.get_texts()[0].get_text() == 'A' assert leg.get_texts()[1].get_text() == 'B' + self.plt.close(fig) - plt.clf() + fig = self.plt.figure() ax = fig.add_subplot(211) df = tm.makeTimeDataFrame() - ax = df.plot(secondary_y=['C', 'D']) + ax = df.plot(secondary_y=['C', 'D'], ax=ax) leg = ax.get_legend() assert len(leg.get_lines()) == 4 assert ax.right_ax.get_legend() is None @@ -1122,12 +1128,13 @@ def test_secondary_legend(self): # TODO: color cycle problems assert len(colors) == 4 + self.plt.close(fig) # non-ts df = tm.makeDataFrame() - plt.clf() + fig = self.plt.figure() ax = fig.add_subplot(211) - ax = df.plot(secondary_y=['A', 'B']) + ax = df.plot(secondary_y=['A', 'B'], ax=ax) leg = ax.get_legend() assert len(leg.get_lines()) == 4 assert ax.right_ax.get_legend() is None @@ -1137,10 +1144,11 @@ def test_secondary_legend(self): # TODO: color cycle problems assert len(colors) == 4 + self.plt.close() - plt.clf() + fig = self.plt.figure() ax = fig.add_subplot(211) - ax = df.plot(secondary_y=['C', 'D']) + ax = df.plot(secondary_y=['C', 'D'], ax=ax) leg = ax.get_legend() assert len(leg.get_lines()) == 4 assert ax.right_ax.get_legend() is None @@ -1154,7 +1162,8 @@ def test_secondary_legend(self): def test_format_date_axis(self): rng = date_range('1/1/2012', periods=12, freq='M') df = DataFrame(np.random.randn(len(rng), 3), rng) - ax = df.plot() + _, ax = self.plt.subplots() + ax = df.plot(ax=ax) xaxis = ax.get_xaxis() for l in xaxis.get_ticklabels(): if len(l.get_text()) > 0: @@ -1162,28 +1171,21 @@ def test_format_date_axis(self): @slow def test_ax_plot(self): - import matplotlib.pyplot as plt - x = DatetimeIndex(start='2012-01-02', periods=10, freq='D') y = lrange(len(x)) - fig = plt.figure() - ax = fig.add_subplot(111) + _, ax = self.plt.subplots() lines = ax.plot(x, y, label='Y') tm.assert_index_equal(DatetimeIndex(lines[0].get_xdata()), x) @slow def test_mpl_nopandas(self): - import matplotlib.pyplot as plt - dates = [date(2008, 12, 31), date(2009, 1, 31)] values1 = np.arange(10.0, 11.0, 0.5) values2 = np.arange(11.0, 12.0, 0.5) kw = dict(fmt='-', lw=4) - plt.close('all') - fig = plt.figure() - ax = fig.add_subplot(111) + _, ax = self.plt.subplots() ax.plot_date([x.toordinal() for x in dates], values1, **kw) ax.plot_date([x.toordinal() for x in dates], values2, **kw) @@ -1201,7 +1203,8 @@ def test_irregular_ts_shared_ax_xlim(self): ts_irregular = ts[[1, 4, 5, 6, 8, 9, 10, 12, 13, 14, 15, 17, 18]] # plot the left section of the irregular series, then the right section - ax = ts_irregular[:5].plot() + _, ax = self.plt.subplots() + ts_irregular[:5].plot(ax=ax) ts_irregular[5:].plot(ax=ax) # check that axis limits are correct @@ -1217,7 +1220,8 @@ def test_secondary_y_non_ts_xlim(self): s1 = Series(1, index=index_1) s2 = Series(2, index=index_2) - ax = s1.plot() + _, ax = self.plt.subplots() + s1.plot(ax=ax) left_before, right_before = ax.get_xlim() s2.plot(secondary_y=True, ax=ax) left_after, right_after = ax.get_xlim() @@ -1233,7 +1237,8 @@ def test_secondary_y_regular_ts_xlim(self): s1 = Series(1, index=index_1) s2 = Series(2, index=index_2) - ax = s1.plot() + _, ax = self.plt.subplots() + s1.plot(ax=ax) left_before, right_before = ax.get_xlim() s2.plot(secondary_y=True, ax=ax) left_after, right_after = ax.get_xlim() @@ -1247,7 +1252,8 @@ def test_secondary_y_mixed_freq_ts_xlim(self): rng = date_range('2000-01-01', periods=10000, freq='min') ts = Series(1, index=rng) - ax = ts.plot() + _, ax = self.plt.subplots() + ts.plot(ax=ax) left_before, right_before = ax.get_xlim() ts.resample('D').mean().plot(secondary_y=True, ax=ax) left_after, right_after = ax.get_xlim() @@ -1262,7 +1268,8 @@ def test_secondary_y_irregular_ts_xlim(self): ts = tm.makeTimeSeries()[:20] ts_irregular = ts[[1, 4, 5, 6, 8, 9, 10, 12, 13, 14, 15, 17, 18]] - ax = ts_irregular[:5].plot() + _, ax = self.plt.subplots() + ts_irregular[:5].plot(ax=ax) # plot higher-x values on secondary axis ts_irregular[5:].plot(secondary_y=True, ax=ax) # ensure secondary limits aren't overwritten by plot on primary @@ -1275,10 +1282,11 @@ def test_secondary_y_irregular_ts_xlim(self): def test_plot_outofbounds_datetime(self): # 2579 - checking this does not raise values = [date(1677, 1, 1), date(1677, 1, 2)] - self.plt.plot(values) + _, ax = self.plt.subplots() + ax.plot(values) values = [datetime(1677, 1, 1, 12), datetime(1677, 1, 2, 12)] - self.plt.plot(values) + ax.plot(values) def test_format_timedelta_ticks_narrow(self): if is_platform_mac(): @@ -1290,8 +1298,8 @@ def test_format_timedelta_ticks_narrow(self): rng = timedelta_range('0', periods=10, freq='ns') df = DataFrame(np.random.randn(len(rng), 3), rng) - ax = df.plot(fontsize=2) - fig = ax.get_figure() + fig, ax = self.plt.subplots() + df.plot(fontsize=2, ax=ax) fig.canvas.draw() labels = ax.get_xticklabels() assert len(labels) == len(expected_labels) @@ -1316,8 +1324,8 @@ def test_format_timedelta_ticks_wide(self): rng = timedelta_range('0', periods=10, freq='1 d') df = DataFrame(np.random.randn(len(rng), 3), rng) - ax = df.plot(fontsize=2) - fig = ax.get_figure() + fig, ax = self.plt.subplots() + ax = df.plot(fontsize=2, ax=ax) fig.canvas.draw() labels = ax.get_xticklabels() assert len(labels) == len(expected_labels) @@ -1327,19 +1335,22 @@ def test_format_timedelta_ticks_wide(self): def test_timedelta_plot(self): # test issue #8711 s = Series(range(5), timedelta_range('1day', periods=5)) - _check_plot_works(s.plot) + _, ax = self.plt.subplots() + _check_plot_works(s.plot, ax=ax) # test long period index = timedelta_range('1 day 2 hr 30 min 10 s', periods=10, freq='1 d') s = Series(np.random.randn(len(index)), index) - _check_plot_works(s.plot) + _, ax = self.plt.subplots() + _check_plot_works(s.plot, ax=ax) # test short period index = timedelta_range('1 day 2 hr 30 min 10 s', periods=10, freq='1 ns') s = Series(np.random.randn(len(index)), index) - _check_plot_works(s.plot) + _, ax = self.plt.subplots() + _check_plot_works(s.plot, ax=ax) def test_hist(self): # https://github.com/matplotlib/matplotlib/issues/8459 @@ -1347,7 +1358,8 @@ def test_hist(self): x = rng w1 = np.arange(0, 1, .1) w2 = np.arange(0, 1, .1)[::-1] - self.plt.hist([x, x], weights=[w1, w2]) + _, ax = self.plt.subplots() + ax.hist([x, x], weights=[w1, w2]) @slow def test_overlapping_datetime(self): @@ -1361,7 +1373,8 @@ def test_overlapping_datetime(self): # plot first series, then add the second series to those axes, # then try adding the first series again - ax = s1.plot() + _, ax = self.plt.subplots() + s1.plot(ax=ax) s2.plot(ax=ax) s1.plot(ax=ax) diff --git a/pandas/tests/plotting/test_series.py b/pandas/tests/plotting/test_series.py index 340a98484480f..7c66b5dafb9c7 100644 --- a/pandas/tests/plotting/test_series.py +++ b/pandas/tests/plotting/test_series.py @@ -82,7 +82,8 @@ def test_plot(self): @slow def test_plot_figsize_and_title(self): # figsize and title - ax = self.series.plot(title='Test', figsize=(16, 8)) + _, ax = self.plt.subplots() + ax = self.series.plot(title='Test', figsize=(16, 8), ax=ax) self._check_text_labels(ax.title, 'Test') self._check_axes_shape(ax, axes_num=1, layout=(1, 1), figsize=(16, 8)) @@ -93,25 +94,28 @@ def test_dont_modify_rcParams(self): else: key = 'axes.color_cycle' colors = self.plt.rcParams[key] - Series([1, 2, 3]).plot() + _, ax = self.plt.subplots() + Series([1, 2, 3]).plot(ax=ax) assert colors == self.plt.rcParams[key] def test_ts_line_lim(self): - ax = self.ts.plot() + fig, ax = self.plt.subplots() + ax = self.ts.plot(ax=ax) xmin, xmax = ax.get_xlim() lines = ax.get_lines() assert xmin == lines[0].get_data(orig=False)[0][0] assert xmax == lines[0].get_data(orig=False)[0][-1] tm.close() - ax = self.ts.plot(secondary_y=True) + ax = self.ts.plot(secondary_y=True, ax=ax) xmin, xmax = ax.get_xlim() lines = ax.get_lines() assert xmin == lines[0].get_data(orig=False)[0][0] assert xmax == lines[0].get_data(orig=False)[0][-1] def test_ts_area_lim(self): - ax = self.ts.plot.area(stacked=False) + _, ax = self.plt.subplots() + ax = self.ts.plot.area(stacked=False, ax=ax) xmin, xmax = ax.get_xlim() line = ax.get_lines()[0].get_data(orig=False)[0] assert xmin == line[0] @@ -119,7 +123,8 @@ def test_ts_area_lim(self): tm.close() # GH 7471 - ax = self.ts.plot.area(stacked=False, x_compat=True) + _, ax = self.plt.subplots() + ax = self.ts.plot.area(stacked=False, x_compat=True, ax=ax) xmin, xmax = ax.get_xlim() line = ax.get_lines()[0].get_data(orig=False)[0] assert xmin == line[0] @@ -128,14 +133,16 @@ def test_ts_area_lim(self): tz_ts = self.ts.copy() tz_ts.index = tz_ts.tz_localize('GMT').tz_convert('CET') - ax = tz_ts.plot.area(stacked=False, x_compat=True) + _, ax = self.plt.subplots() + ax = tz_ts.plot.area(stacked=False, x_compat=True, ax=ax) xmin, xmax = ax.get_xlim() line = ax.get_lines()[0].get_data(orig=False)[0] assert xmin == line[0] assert xmax == line[-1] tm.close() - ax = tz_ts.plot.area(stacked=False, secondary_y=True) + _, ax = self.plt.subplots() + ax = tz_ts.plot.area(stacked=False, secondary_y=True, ax=ax) xmin, xmax = ax.get_xlim() line = ax.get_lines()[0].get_data(orig=False)[0] assert xmin == line[0] @@ -143,23 +150,28 @@ def test_ts_area_lim(self): def test_label(self): s = Series([1, 2]) - ax = s.plot(label='LABEL', legend=True) + _, ax = self.plt.subplots() + ax = s.plot(label='LABEL', legend=True, ax=ax) self._check_legend_labels(ax, labels=['LABEL']) self.plt.close() - ax = s.plot(legend=True) + _, ax = self.plt.subplots() + ax = s.plot(legend=True, ax=ax) self._check_legend_labels(ax, labels=['None']) self.plt.close() # get name from index s.name = 'NAME' - ax = s.plot(legend=True) + _, ax = self.plt.subplots() + ax = s.plot(legend=True, ax=ax) self._check_legend_labels(ax, labels=['NAME']) self.plt.close() # override the default - ax = s.plot(legend=True, label='LABEL') + _, ax = self.plt.subplots() + ax = s.plot(legend=True, label='LABEL', ax=ax) self._check_legend_labels(ax, labels=['LABEL']) self.plt.close() # Add lebel info, but don't draw - ax = s.plot(legend=False, label='LABEL') + _, ax = self.plt.subplots() + ax = s.plot(legend=False, label='LABEL', ax=ax) assert ax.get_legend() is None # Hasn't been drawn ax.legend() # draw it self._check_legend_labels(ax, labels=['LABEL']) @@ -189,10 +201,12 @@ def test_line_area_nan_series(self): def test_line_use_index_false(self): s = Series([1, 2, 3], index=['a', 'b', 'c']) s.index.name = 'The Index' - ax = s.plot(use_index=False) + _, ax = self.plt.subplots() + ax = s.plot(use_index=False, ax=ax) label = ax.get_xlabel() assert label == '' - ax2 = s.plot.bar(use_index=False) + _, ax = self.plt.subplots() + ax2 = s.plot.bar(use_index=False, ax=ax) label2 = ax2.get_xlabel() assert label2 == '' @@ -203,11 +217,13 @@ def test_bar_log(self): if not self.mpl_le_1_2_1: expected = np.hstack((.1, expected, 1e4)) - ax = Series([200, 500]).plot.bar(log=True) + _, ax = self.plt.subplots() + ax = Series([200, 500]).plot.bar(log=True, ax=ax) tm.assert_numpy_array_equal(ax.yaxis.get_ticklocs(), expected) tm.close() - ax = Series([200, 500]).plot.barh(log=True) + _, ax = self.plt.subplots() + ax = Series([200, 500]).plot.barh(log=True, ax=ax) tm.assert_numpy_array_equal(ax.xaxis.get_ticklocs(), expected) tm.close() @@ -219,7 +235,8 @@ def test_bar_log(self): if self.mpl_ge_2_0_0: expected = np.hstack((1.0e-05, expected)) - ax = Series([0.1, 0.01, 0.001]).plot(log=True, kind='bar') + _, ax = self.plt.subplots() + ax = Series([0.1, 0.01, 0.001]).plot(log=True, kind='bar', ax=ax) ymin = 0.0007943282347242822 if self.mpl_ge_2_0_0 else 0.001 ymax = 0.12589254117941673 if self.mpl_ge_2_0_0 else .10000000000000001 res = ax.get_ylim() @@ -228,7 +245,8 @@ def test_bar_log(self): tm.assert_numpy_array_equal(ax.yaxis.get_ticklocs(), expected) tm.close() - ax = Series([0.1, 0.01, 0.001]).plot(log=True, kind='barh') + _, ax = self.plt.subplots() + ax = Series([0.1, 0.01, 0.001]).plot(log=True, kind='barh', ax=ax) res = ax.get_xlim() tm.assert_almost_equal(res[0], ymin) tm.assert_almost_equal(res[1], ymax) @@ -237,23 +255,27 @@ def test_bar_log(self): @slow def test_bar_ignore_index(self): df = Series([1, 2, 3, 4], index=['a', 'b', 'c', 'd']) - ax = df.plot.bar(use_index=False) + _, ax = self.plt.subplots() + ax = df.plot.bar(use_index=False, ax=ax) self._check_text_labels(ax.get_xticklabels(), ['0', '1', '2', '3']) def test_rotation(self): df = DataFrame(randn(5, 5)) # Default rot 0 - axes = df.plot() + _, ax = self.plt.subplots() + axes = df.plot(ax=ax) self._check_ticks_props(axes, xrot=0) - axes = df.plot(rot=30) + _, ax = self.plt.subplots() + axes = df.plot(rot=30, ax=ax) self._check_ticks_props(axes, xrot=30) def test_irregular_datetime(self): rng = date_range('1/1/2000', '3/1/2000') rng = rng[[0, 1, 2, 3, 5, 9, 10, 11, 12]] ser = Series(randn(len(rng)), rng) - ax = ser.plot() + _, ax = self.plt.subplots() + ax = ser.plot(ax=ax) xp = datetime(1999, 1, 1).toordinal() ax.set_xlim('1/1/1999', '1/1/2001') assert xp == ax.get_xlim()[0] @@ -311,7 +333,8 @@ def test_pie_series(self): def test_pie_nan(self): s = Series([1, np.nan, 1, 1]) - ax = s.plot.pie(legend=True) + _, ax = self.plt.subplots() + ax = s.plot.pie(legend=True, ax=ax) expected = ['0', '', '2', '3'] result = [x.get_text() for x in ax.texts] assert result == expected @@ -319,7 +342,8 @@ def test_pie_nan(self): @slow def test_hist_df_kwargs(self): df = DataFrame(np.random.randn(10, 2)) - ax = df.plot.hist(bins=5) + _, ax = self.plt.subplots() + ax = df.plot.hist(bins=5, ax=ax) assert len(ax.patches) == 10 @slow @@ -329,10 +353,12 @@ def test_hist_df_with_nonnumerics(self): df = DataFrame( np.random.randn(10, 4), columns=['A', 'B', 'C', 'D']) df['E'] = ['x', 'y'] * 5 - ax = df.plot.hist(bins=5) + _, ax = self.plt.subplots() + ax = df.plot.hist(bins=5, ax=ax) assert len(ax.patches) == 20 - ax = df.plot.hist() # bins=10 + _, ax = self.plt.subplots() + ax = df.plot.hist(ax=ax) # bins=10 assert len(ax.patches) == 40 @slow @@ -439,7 +465,8 @@ def test_hist_secondary_legend(self): df = DataFrame(np.random.randn(30, 4), columns=list('abcd')) # primary -> secondary - ax = df['a'].plot.hist(legend=True) + _, ax = self.plt.subplots() + ax = df['a'].plot.hist(legend=True, ax=ax) df['b'].plot.hist(ax=ax, legend=True, secondary_y=True) # both legends are dran on left ax # left and right axis must be visible @@ -449,7 +476,8 @@ def test_hist_secondary_legend(self): tm.close() # secondary -> secondary - ax = df['a'].plot.hist(legend=True, secondary_y=True) + _, ax = self.plt.subplots() + ax = df['a'].plot.hist(legend=True, secondary_y=True, ax=ax) df['b'].plot.hist(ax=ax, legend=True, secondary_y=True) # both legends are draw on left ax # left axis must be invisible, right axis must be visible @@ -460,7 +488,8 @@ def test_hist_secondary_legend(self): tm.close() # secondary -> primary - ax = df['a'].plot.hist(legend=True, secondary_y=True) + _, ax = self.plt.subplots() + ax = df['a'].plot.hist(legend=True, secondary_y=True, ax=ax) # right axes is returned df['b'].plot.hist(ax=ax, legend=True) # both legends are draw on left ax @@ -477,8 +506,9 @@ def test_df_series_secondary_legend(self): s = Series(np.random.randn(30), name='x') # primary -> secondary (without passing ax) - ax = df.plot() - s.plot(legend=True, secondary_y=True) + _, ax = self.plt.subplots() + ax = df.plot(ax=ax) + s.plot(legend=True, secondary_y=True, ax=ax) # both legends are dran on left ax # left and right axis must be visible self._check_legend_labels(ax, labels=['a', 'b', 'c', 'x (right)']) @@ -487,7 +517,8 @@ def test_df_series_secondary_legend(self): tm.close() # primary -> secondary (with passing ax) - ax = df.plot() + _, ax = self.plt.subplots() + ax = df.plot(ax=ax) s.plot(ax=ax, legend=True, secondary_y=True) # both legends are dran on left ax # left and right axis must be visible @@ -497,8 +528,9 @@ def test_df_series_secondary_legend(self): tm.close() # seconcary -> secondary (without passing ax) - ax = df.plot(secondary_y=True) - s.plot(legend=True, secondary_y=True) + _, ax = self.plt.subplots() + ax = df.plot(secondary_y=True, ax=ax) + s.plot(legend=True, secondary_y=True, ax=ax) # both legends are dran on left ax # left axis must be invisible and right axis must be visible expected = ['a (right)', 'b (right)', 'c (right)', 'x (right)'] @@ -508,7 +540,8 @@ def test_df_series_secondary_legend(self): tm.close() # secondary -> secondary (with passing ax) - ax = df.plot(secondary_y=True) + _, ax = self.plt.subplots() + ax = df.plot(secondary_y=True, ax=ax) s.plot(ax=ax, legend=True, secondary_y=True) # both legends are dran on left ax # left axis must be invisible and right axis must be visible @@ -519,7 +552,8 @@ def test_df_series_secondary_legend(self): tm.close() # secondary -> secondary (with passing ax) - ax = df.plot(secondary_y=True, mark_right=False) + _, ax = self.plt.subplots() + ax = df.plot(secondary_y=True, mark_right=False, ax=ax) s.plot(ax=ax, legend=True, secondary_y=True) # both legends are dran on left ax # left axis must be invisible and right axis must be visible @@ -533,11 +567,13 @@ def test_df_series_secondary_legend(self): def test_plot_fails_with_dupe_color_and_style(self): x = Series(randn(2)) with pytest.raises(ValueError): - x.plot(style='k--', color='k') + _, ax = self.plt.subplots() + x.plot(style='k--', color='k', ax=ax) @slow def test_hist_kde(self): - ax = self.ts.plot.hist(logy=True) + _, ax = self.plt.subplots() + ax = self.ts.plot.hist(logy=True, ax=ax) self._check_ax_scales(ax, yaxis='log') xlabels = ax.get_xticklabels() # ticks are values, thus ticklabels are blank @@ -549,7 +585,8 @@ def test_hist_kde(self): _skip_if_no_scipy_gaussian_kde() _check_plot_works(self.ts.plot.kde) _check_plot_works(self.ts.plot.density) - ax = self.ts.plot.kde(logy=True) + _, ax = self.plt.subplots() + ax = self.ts.plot.kde(logy=True, ax=ax) self._check_ax_scales(ax, yaxis='log') xlabels = ax.get_xticklabels() self._check_text_labels(xlabels, [''] * len(xlabels)) @@ -565,8 +602,9 @@ def test_kde_kwargs(self): ind=linspace(-100, 100, 20)) _check_plot_works(self.ts.plot.density, bw_method=.5, ind=linspace(-100, 100, 20)) + _, ax = self.plt.subplots() ax = self.ts.plot.kde(logy=True, bw_method=.5, - ind=linspace(-100, 100, 20)) + ind=linspace(-100, 100, 20), ax=ax) self._check_ax_scales(ax, yaxis='log') self._check_text_labels(ax.yaxis.get_label(), 'Density') @@ -583,29 +621,34 @@ def test_kde_missing_vals(self): @slow def test_hist_kwargs(self): - ax = self.ts.plot.hist(bins=5) + _, ax = self.plt.subplots() + ax = self.ts.plot.hist(bins=5, ax=ax) assert len(ax.patches) == 5 self._check_text_labels(ax.yaxis.get_label(), 'Frequency') tm.close() if self.mpl_ge_1_3_1: - ax = self.ts.plot.hist(orientation='horizontal') + _, ax = self.plt.subplots() + ax = self.ts.plot.hist(orientation='horizontal', ax=ax) self._check_text_labels(ax.xaxis.get_label(), 'Frequency') tm.close() - ax = self.ts.plot.hist(align='left', stacked=True) + _, ax = self.plt.subplots() + ax = self.ts.plot.hist(align='left', stacked=True, ax=ax) tm.close() @slow def test_hist_kde_color(self): - ax = self.ts.plot.hist(logy=True, bins=10, color='b') + _, ax = self.plt.subplots() + ax = self.ts.plot.hist(logy=True, bins=10, color='b', ax=ax) self._check_ax_scales(ax, yaxis='log') assert len(ax.patches) == 10 self._check_colors(ax.patches, facecolors=['b'] * 10) tm._skip_if_no_scipy() _skip_if_no_scipy_gaussian_kde() - ax = self.ts.plot.kde(logy=True, color='r') + _, ax = self.plt.subplots() + ax = self.ts.plot.kde(logy=True, color='r', ax=ax) self._check_ax_scales(ax, yaxis='log') lines = ax.get_lines() assert len(lines) == 1 @@ -613,7 +656,8 @@ def test_hist_kde_color(self): @slow def test_boxplot_series(self): - ax = self.ts.plot.box(logy=True) + _, ax = self.plt.subplots() + ax = self.ts.plot.box(logy=True, ax=ax) self._check_ax_scales(ax, yaxis='log') xlabels = ax.get_xticklabels() self._check_text_labels(xlabels, [self.ts.name]) @@ -625,20 +669,22 @@ def test_kind_both_ways(self): s = Series(range(3)) kinds = (plotting._core._common_kinds + plotting._core._series_kinds) + _, ax = self.plt.subplots() for kind in kinds: if not _ok_for_gaussian_kde(kind): continue - s.plot(kind=kind) + s.plot(kind=kind, ax=ax) getattr(s.plot, kind)() @slow def test_invalid_plot_data(self): s = Series(list('abcd')) + _, ax = self.plt.subplots() for kind in plotting._core._common_kinds: if not _ok_for_gaussian_kde(kind): continue with pytest.raises(TypeError): - s.plot(kind=kind) + s.plot(kind=kind, ax=ax) @slow def test_valid_object_plot(self): @@ -650,11 +696,12 @@ def test_valid_object_plot(self): def test_partially_invalid_plot_data(self): s = Series(['a', 'b', 1.0, 2]) + _, ax = self.plt.subplots() for kind in plotting._core._common_kinds: if not _ok_for_gaussian_kde(kind): continue with pytest.raises(TypeError): - s.plot(kind=kind) + s.plot(kind=kind, ax=ax) def test_invalid_kind(self): s = Series([1, 2]) @@ -776,13 +823,15 @@ def test_standard_colors_all(self): def test_series_plot_color_kwargs(self): # GH1890 - ax = Series(np.arange(12) + 1).plot(color='green') + _, ax = self.plt.subplots() + ax = Series(np.arange(12) + 1).plot(color='green', ax=ax) self._check_colors(ax.get_lines(), linecolors=['green']) def test_time_series_plot_color_kwargs(self): # #1890 + _, ax = self.plt.subplots() ax = Series(np.arange(12) + 1, index=date_range( - '1/1/2000', periods=12)).plot(color='green') + '1/1/2000', periods=12)).plot(color='green', ax=ax) self._check_colors(ax.get_lines(), linecolors=['green']) def test_time_series_plot_color_with_empty_kwargs(self): @@ -797,14 +846,16 @@ def test_time_series_plot_color_with_empty_kwargs(self): ncolors = 3 + _, ax = self.plt.subplots() for i in range(ncolors): - ax = s.plot() + ax = s.plot(ax=ax) self._check_colors(ax.get_lines(), linecolors=def_colors[:ncolors]) def test_xticklabels(self): # GH11529 s = Series(np.arange(10), index=['P%02d' % i for i in range(10)]) - ax = s.plot(xticks=[0, 3, 5, 9]) + _, ax = self.plt.subplots() + ax = s.plot(xticks=[0, 3, 5, 9], ax=ax) exp = ['P%02d' % i for i in [0, 3, 5, 9]] self._check_text_labels(ax.get_xticklabels(), exp)