diff --git a/pandas/_testing/__init__.py b/pandas/_testing/__init__.py index d35242ada21e9..85d03ea17bf42 100644 --- a/pandas/_testing/__init__.py +++ b/pandas/_testing/__init__.py @@ -57,7 +57,6 @@ assert_indexing_slices_equivalent, assert_interval_array_equal, assert_is_sorted, - assert_is_valid_plot_return_object, assert_metadata_equivalent, assert_numpy_array_equal, assert_period_array_equal, @@ -558,7 +557,6 @@ def shares_memory(left, right) -> bool: "assert_indexing_slices_equivalent", "assert_interval_array_equal", "assert_is_sorted", - "assert_is_valid_plot_return_object", "assert_metadata_equivalent", "assert_numpy_array_equal", "assert_period_array_equal", diff --git a/pandas/_testing/asserters.py b/pandas/_testing/asserters.py index 430840711122a..1127a4512643c 100644 --- a/pandas/_testing/asserters.py +++ b/pandas/_testing/asserters.py @@ -429,28 +429,6 @@ def assert_attr_equal(attr: str, left, right, obj: str = "Attributes") -> None: return None -def assert_is_valid_plot_return_object(objs) -> None: - from matplotlib.artist import Artist - from matplotlib.axes import Axes - - if isinstance(objs, (Series, np.ndarray)): - if isinstance(objs, Series): - objs = objs._values - for el in objs.ravel(): - msg = ( - "one of 'objs' is not a matplotlib Axes instance, " - f"type encountered {type(el).__name__!r}" - ) - assert isinstance(el, (Axes, dict)), msg - else: - msg = ( - "objs is neither an ndarray of Artist instances nor a single " - "ArtistArtist instance, tuple, or dict, 'objs' is a " - f"{type(objs).__name__!r}" - ) - assert isinstance(objs, (Artist, tuple, dict)), msg - - def assert_is_sorted(seq) -> None: """Assert that the sequence is sorted.""" if isinstance(seq, (Index, Series)): diff --git a/pandas/tests/plotting/common.py b/pandas/tests/plotting/common.py index 5a46cdcb051b6..d8c49d6d47f28 100644 --- a/pandas/tests/plotting/common.py +++ b/pandas/tests/plotting/common.py @@ -76,8 +76,6 @@ def _check_data(xp, rs): xp : matplotlib Axes object rs : matplotlib Axes object """ - import matplotlib.pyplot as plt - xp_lines = xp.get_lines() rs_lines = rs.get_lines() @@ -87,8 +85,6 @@ def _check_data(xp, rs): rsdata = rsl.get_xydata() tm.assert_almost_equal(xpdata, rsdata) - plt.close("all") - def _check_visible(collections, visible=True): """ @@ -495,6 +491,28 @@ def get_y_axis(ax): return ax._shared_axes["y"] +def assert_is_valid_plot_return_object(objs) -> None: + from matplotlib.artist import Artist + from matplotlib.axes import Axes + + if isinstance(objs, (Series, np.ndarray)): + if isinstance(objs, Series): + objs = objs._values + for el in objs.reshape(-1): + msg = ( + "one of 'objs' is not a matplotlib Axes instance, " + f"type encountered {type(el).__name__!r}" + ) + assert isinstance(el, (Axes, dict)), msg + else: + msg = ( + "objs is neither an ndarray of Artist instances nor a single " + "ArtistArtist instance, tuple, or dict, 'objs' is a " + f"{type(objs).__name__!r}" + ) + assert isinstance(objs, (Artist, tuple, dict)), msg + + def _check_plot_works(f, default_axes=False, **kwargs): """ Create plot and ensure that plot return object is valid. @@ -530,15 +548,11 @@ def _check_plot_works(f, default_axes=False, **kwargs): gen_plots = _gen_two_subplots ret = None - try: - fig = kwargs.get("figure", plt.gcf()) - plt.clf() - - for ret in gen_plots(f, fig, **kwargs): - tm.assert_is_valid_plot_return_object(ret) + fig = kwargs.get("figure", plt.gcf()) + fig.clf() - finally: - plt.close(fig) + for ret in gen_plots(f, fig, **kwargs): + assert_is_valid_plot_return_object(ret) return ret