Skip to content

Commit c226fc5

Browse files
authored
TST: fix warnings on multiple subplots (#37274)
1 parent 085f27b commit c226fc5

File tree

2 files changed

+84
-23
lines changed

2 files changed

+84
-23
lines changed

pandas/tests/plotting/common.py

+62-18
Original file line numberDiff line numberDiff line change
@@ -550,41 +550,85 @@ def _unpack_cycler(self, rcParams, field="color"):
550550
return [v[field] for v in rcParams["axes.prop_cycle"]]
551551

552552

553-
def _check_plot_works(f, filterwarnings="always", **kwargs):
553+
def _check_plot_works(f, filterwarnings="always", default_axes=False, **kwargs):
554+
"""
555+
Create plot and ensure that plot return object is valid.
556+
557+
Parameters
558+
----------
559+
f : func
560+
Plotting function.
561+
filterwarnings : str
562+
Warnings filter.
563+
See https://docs.python.org/3/library/warnings.html#warning-filter
564+
default_axes : bool, optional
565+
If False (default):
566+
- If `ax` not in `kwargs`, then create subplot(211) and plot there
567+
- Create new subplot(212) and plot there as well
568+
- Mind special corner case for bootstrap_plot (see `_gen_two_subplots`)
569+
If True:
570+
- Simply run plotting function with kwargs provided
571+
- All required axes instances will be created automatically
572+
- It is recommended to use it when the plotting function
573+
creates multiple axes itself. It helps avoid warnings like
574+
'UserWarning: To output multiple subplots,
575+
the figure containing the passed axes is being cleared'
576+
**kwargs
577+
Keyword arguments passed to the plotting function.
578+
579+
Returns
580+
-------
581+
Plot object returned by the last plotting.
582+
"""
554583
import matplotlib.pyplot as plt
555584

585+
if default_axes:
586+
gen_plots = _gen_default_plot
587+
else:
588+
gen_plots = _gen_two_subplots
589+
556590
ret = None
557591
with warnings.catch_warnings():
558592
warnings.simplefilter(filterwarnings)
559593
try:
560-
try:
561-
fig = kwargs["figure"]
562-
except KeyError:
563-
fig = plt.gcf()
564-
594+
fig = kwargs.get("figure", plt.gcf())
565595
plt.clf()
566596

567-
kwargs.get("ax", fig.add_subplot(211))
568-
ret = f(**kwargs)
569-
570-
tm.assert_is_valid_plot_return_object(ret)
571-
572-
if f is pd.plotting.bootstrap_plot:
573-
assert "ax" not in kwargs
574-
else:
575-
kwargs["ax"] = fig.add_subplot(212)
576-
577-
ret = f(**kwargs)
578-
tm.assert_is_valid_plot_return_object(ret)
597+
for ret in gen_plots(f, fig, **kwargs):
598+
tm.assert_is_valid_plot_return_object(ret)
579599

580600
with tm.ensure_clean(return_filelike=True) as path:
581601
plt.savefig(path)
602+
603+
except Exception as err:
604+
raise err
582605
finally:
583606
tm.close(fig)
584607

585608
return ret
586609

587610

611+
def _gen_default_plot(f, fig, **kwargs):
612+
"""
613+
Create plot in a default way.
614+
"""
615+
yield f(**kwargs)
616+
617+
618+
def _gen_two_subplots(f, fig, **kwargs):
619+
"""
620+
Create plot on two subplots forcefully created.
621+
"""
622+
kwargs.get("ax", fig.add_subplot(211))
623+
yield f(**kwargs)
624+
625+
if f is pd.plotting.bootstrap_plot:
626+
assert "ax" not in kwargs
627+
else:
628+
kwargs["ax"] = fig.add_subplot(212)
629+
yield f(**kwargs)
630+
631+
588632
def curpath():
589633
pth, _ = os.path.split(os.path.abspath(__file__))
590634
return pth

pandas/tests/plotting/test_hist_method.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ def test_hist_with_legend(self, by, expected_axes_num, expected_layout):
152152
s = Series(np.random.randn(30), index=index, name="a")
153153
s.index.name = "b"
154154

155-
axes = _check_plot_works(s.hist, legend=True, by=by)
155+
# Use default_axes=True when plotting method generate subplots itself
156+
axes = _check_plot_works(s.hist, default_axes=True, legend=True, by=by)
156157
self._check_axes_shape(axes, axes_num=expected_axes_num, layout=expected_layout)
157158
self._check_legend_labels(axes, "a")
158159

@@ -332,7 +333,8 @@ def test_tight_layout(self):
332333
dtype=np.int64,
333334
)
334335
)
335-
_check_plot_works(df.hist)
336+
# Use default_axes=True when plotting method generate subplots itself
337+
_check_plot_works(df.hist, default_axes=True)
336338
self.plt.tight_layout()
337339

338340
tm.close()
@@ -345,8 +347,10 @@ def test_hist_subplot_xrot(self):
345347
"animal": ["pig", "rabbit", "pig", "pig", "rabbit"],
346348
}
347349
)
350+
# Use default_axes=True when plotting method generate subplots itself
348351
axes = _check_plot_works(
349352
df.hist,
353+
default_axes=True,
350354
filterwarnings="always",
351355
column="length",
352356
by="animal",
@@ -374,9 +378,14 @@ def test_hist_column_order_unchanged(self, column, expected):
374378
index=["pig", "rabbit", "duck", "chicken", "horse"],
375379
)
376380

377-
axes = _check_plot_works(df.hist, column=column, layout=(1, 3))
381+
# Use default_axes=True when plotting method generate subplots itself
382+
axes = _check_plot_works(
383+
df.hist,
384+
default_axes=True,
385+
column=column,
386+
layout=(1, 3),
387+
)
378388
result = [axes[0, i].get_title() for i in range(3)]
379-
380389
assert result == expected
381390

382391
@pytest.mark.parametrize(
@@ -407,7 +416,15 @@ def test_hist_with_legend(self, by, column):
407416
index = Index(15 * ["1"] + 15 * ["2"], name="c")
408417
df = DataFrame(np.random.randn(30, 2), index=index, columns=["a", "b"])
409418

410-
axes = _check_plot_works(df.hist, legend=True, by=by, column=column)
419+
# Use default_axes=True when plotting method generate subplots itself
420+
axes = _check_plot_works(
421+
df.hist,
422+
default_axes=True,
423+
legend=True,
424+
by=by,
425+
column=column,
426+
)
427+
411428
self._check_axes_shape(axes, axes_num=expected_axes_num, layout=expected_layout)
412429
if by is None and column is None:
413430
axes = axes[0]

0 commit comments

Comments
 (0)