diff --git a/pandas/tests/plotting/common.py b/pandas/tests/plotting/common.py index 82a62c4588b94..c868c8d4fba07 100644 --- a/pandas/tests/plotting/common.py +++ b/pandas/tests/plotting/common.py @@ -550,41 +550,85 @@ def _unpack_cycler(self, rcParams, field="color"): return [v[field] for v in rcParams["axes.prop_cycle"]] -def _check_plot_works(f, filterwarnings="always", **kwargs): +def _check_plot_works(f, filterwarnings="always", default_axes=False, **kwargs): + """ + Create plot and ensure that plot return object is valid. + + Parameters + ---------- + f : func + Plotting function. + filterwarnings : str + Warnings filter. + See https://docs.python.org/3/library/warnings.html#warning-filter + default_axes : bool, optional + If False (default): + - If `ax` not in `kwargs`, then create subplot(211) and plot there + - Create new subplot(212) and plot there as well + - Mind special corner case for bootstrap_plot (see `_gen_two_subplots`) + If True: + - Simply run plotting function with kwargs provided + - All required axes instances will be created automatically + - It is recommended to use it when the plotting function + creates multiple axes itself. It helps avoid warnings like + 'UserWarning: To output multiple subplots, + the figure containing the passed axes is being cleared' + **kwargs + Keyword arguments passed to the plotting function. + + Returns + ------- + Plot object returned by the last plotting. + """ import matplotlib.pyplot as plt + if default_axes: + gen_plots = _gen_default_plot + else: + gen_plots = _gen_two_subplots + ret = None with warnings.catch_warnings(): warnings.simplefilter(filterwarnings) try: - try: - fig = kwargs["figure"] - except KeyError: - fig = plt.gcf() - + fig = kwargs.get("figure", plt.gcf()) plt.clf() - kwargs.get("ax", fig.add_subplot(211)) - ret = f(**kwargs) - - tm.assert_is_valid_plot_return_object(ret) - - if f is pd.plotting.bootstrap_plot: - assert "ax" not in kwargs - else: - kwargs["ax"] = fig.add_subplot(212) - - ret = f(**kwargs) - tm.assert_is_valid_plot_return_object(ret) + for ret in gen_plots(f, fig, **kwargs): + tm.assert_is_valid_plot_return_object(ret) with tm.ensure_clean(return_filelike=True) as path: plt.savefig(path) + + except Exception as err: + raise err finally: tm.close(fig) return ret +def _gen_default_plot(f, fig, **kwargs): + """ + Create plot in a default way. + """ + yield f(**kwargs) + + +def _gen_two_subplots(f, fig, **kwargs): + """ + Create plot on two subplots forcefully created. + """ + kwargs.get("ax", fig.add_subplot(211)) + yield f(**kwargs) + + if f is pd.plotting.bootstrap_plot: + assert "ax" not in kwargs + else: + kwargs["ax"] = fig.add_subplot(212) + yield f(**kwargs) + + def curpath(): pth, _ = os.path.split(os.path.abspath(__file__)) return pth diff --git a/pandas/tests/plotting/test_hist_method.py b/pandas/tests/plotting/test_hist_method.py index 49335230171c6..ab0024559333e 100644 --- a/pandas/tests/plotting/test_hist_method.py +++ b/pandas/tests/plotting/test_hist_method.py @@ -152,7 +152,8 @@ def test_hist_with_legend(self, by, expected_axes_num, expected_layout): s = Series(np.random.randn(30), index=index, name="a") s.index.name = "b" - axes = _check_plot_works(s.hist, legend=True, by=by) + # Use default_axes=True when plotting method generate subplots itself + axes = _check_plot_works(s.hist, default_axes=True, legend=True, by=by) self._check_axes_shape(axes, axes_num=expected_axes_num, layout=expected_layout) self._check_legend_labels(axes, "a") @@ -332,7 +333,8 @@ def test_tight_layout(self): dtype=np.int64, ) ) - _check_plot_works(df.hist) + # Use default_axes=True when plotting method generate subplots itself + _check_plot_works(df.hist, default_axes=True) self.plt.tight_layout() tm.close() @@ -345,8 +347,10 @@ def test_hist_subplot_xrot(self): "animal": ["pig", "rabbit", "pig", "pig", "rabbit"], } ) + # Use default_axes=True when plotting method generate subplots itself axes = _check_plot_works( df.hist, + default_axes=True, filterwarnings="always", column="length", by="animal", @@ -374,9 +378,14 @@ def test_hist_column_order_unchanged(self, column, expected): index=["pig", "rabbit", "duck", "chicken", "horse"], ) - axes = _check_plot_works(df.hist, column=column, layout=(1, 3)) + # Use default_axes=True when plotting method generate subplots itself + axes = _check_plot_works( + df.hist, + default_axes=True, + column=column, + layout=(1, 3), + ) result = [axes[0, i].get_title() for i in range(3)] - assert result == expected @pytest.mark.parametrize( @@ -407,7 +416,15 @@ def test_hist_with_legend(self, by, column): index = Index(15 * ["1"] + 15 * ["2"], name="c") df = DataFrame(np.random.randn(30, 2), index=index, columns=["a", "b"]) - axes = _check_plot_works(df.hist, legend=True, by=by, column=column) + # Use default_axes=True when plotting method generate subplots itself + axes = _check_plot_works( + df.hist, + default_axes=True, + legend=True, + by=by, + column=column, + ) + self._check_axes_shape(axes, axes_num=expected_axes_num, layout=expected_layout) if by is None and column is None: axes = axes[0]