Skip to content

Commit 0c34b9b

Browse files
committed
BUG: Fix .hist and .plot.hist when passing existing figure (#37278)
1 parent 290c58a commit 0c34b9b

File tree

7 files changed

+55
-5
lines changed

7 files changed

+55
-5
lines changed

doc/source/whatsnew/v1.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ Other enhancements
214214
- :class:`Index` with object dtype supports division and multiplication (:issue:`34160`)
215215
- :meth:`DataFrame.explode` and :meth:`Series.explode` now support exploding of sets (:issue:`35614`)
216216
- :meth:`DataFrame.hist` now supports time series (datetime) data (:issue:`32590`)
217+
- :meth:`DataFrame.hist` and :meth:`DataFrame.plot.hist` can now be called with an existing matplotlib ``Figure`` object via added ``figure`` argument (:issue:`37278`)
217218
- ``Styler`` now allows direct CSS class name addition to individual data cells (:issue:`36159`)
218219
- :meth:`Rolling.mean()` and :meth:`Rolling.sum()` use Kahan summation to calculate the mean to avoid numerical problems (:issue:`10319`, :issue:`11645`, :issue:`13254`, :issue:`32761`, :issue:`36031`)
219220
- :meth:`DatetimeIndex.searchsorted`, :meth:`TimedeltaIndex.searchsorted`, :meth:`PeriodIndex.searchsorted`, and :meth:`Series.searchsorted` with datetimelike dtypes will now try to cast string arguments (listlike and scalar) to the matching datetimelike type (:issue:`36346`)

pandas/plotting/_core.py

+6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from pandas.core.base import PandasObject
1313

1414
if TYPE_CHECKING:
15+
from matplotlib.figure import Figure
16+
1517
from pandas import DataFrame
1618

1719

@@ -107,6 +109,7 @@ def hist_frame(
107109
xrot: Optional[float] = None,
108110
ylabelsize: Optional[int] = None,
109111
yrot: Optional[float] = None,
112+
figure: Optional["Figure"] = None,
110113
ax=None,
111114
sharex: bool = False,
112115
sharey: bool = False,
@@ -146,6 +149,8 @@ def hist_frame(
146149
yrot : float, default None
147150
Rotation of y axis labels. For example, a value of 90 displays the
148151
y labels rotated 90 degrees clockwise.
152+
figure : Matplotlib Figure object, default None
153+
The figure to plot the histogram on.
149154
ax : Matplotlib axes object, default None
150155
The axes to plot the histogram on.
151156
sharex : bool, default True if ax is None else False
@@ -217,6 +222,7 @@ def hist_frame(
217222
xrot=xrot,
218223
ylabelsize=ylabelsize,
219224
yrot=yrot,
225+
figure=figure,
220226
ax=ax,
221227
sharex=sharex,
222228
sharey=sharey,

pandas/plotting/_matplotlib/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ def plot(data, kind, **kwargs):
5151
# work)
5252
import matplotlib.pyplot as plt
5353

54+
if kwargs.get("figure"):
55+
kwargs["fig"] = kwargs.get("figure")
56+
kwargs["ax"] = kwargs["figure"].gca()
57+
kwargs.pop("reuse_plot", None)
5458
if kwargs.pop("reuse_plot", False):
5559
ax = kwargs.get("ax")
5660
if ax is None and len(plt.get_fignums()) > 0:

pandas/plotting/_matplotlib/core.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -325,12 +325,16 @@ def _setup_subplots(self):
325325
sharex=self.sharex,
326326
sharey=self.sharey,
327327
figsize=self.figsize,
328+
figure=self.fig,
328329
ax=self.ax,
329330
layout=self.layout,
330331
layout_type=self._layout_type,
331332
)
332333
else:
333-
if self.ax is None:
334+
if self.fig is not None:
335+
fig = self.fig
336+
axes = fig.add_subplot(111)
337+
elif self.ax is None:
334338
fig = self.plt.figure(figsize=self.figsize)
335339
axes = fig.add_subplot(111)
336340
else:

pandas/plotting/_matplotlib/hist.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING
1+
from typing import TYPE_CHECKING, Optional
22

33
import numpy as np
44

@@ -16,6 +16,7 @@
1616

1717
if TYPE_CHECKING:
1818
from matplotlib.axes import Axes
19+
from matplotlib.figure import Figure
1920

2021

2122
class HistPlot(LinePlot):
@@ -181,6 +182,7 @@ def _grouped_plot(
181182
column=None,
182183
by=None,
183184
numeric_only=True,
185+
figure: Optional["Figure"] = None,
184186
figsize=None,
185187
sharex=True,
186188
sharey=True,
@@ -203,7 +205,13 @@ def _grouped_plot(
203205

204206
naxes = len(grouped)
205207
fig, axes = create_subplots(
206-
naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout
208+
naxes=naxes,
209+
figure=figure,
210+
figsize=figsize,
211+
sharex=sharex,
212+
sharey=sharey,
213+
ax=ax,
214+
layout=layout,
207215
)
208216

209217
_axes = flatten_axes(axes)
@@ -222,6 +230,7 @@ def _grouped_hist(
222230
data,
223231
column=None,
224232
by=None,
233+
figure=None,
225234
ax=None,
226235
bins=50,
227236
figsize=None,
@@ -245,6 +254,7 @@ def _grouped_hist(
245254
data : Series/DataFrame
246255
column : object, optional
247256
by : object, optional
257+
figure: figure, optional
248258
ax : axes, optional
249259
bins : int, default 50
250260
figsize : tuple, optional
@@ -282,6 +292,7 @@ def plot_group(group, ax):
282292
data,
283293
column=column,
284294
by=by,
295+
figure=figure,
285296
sharex=sharex,
286297
sharey=sharey,
287298
ax=ax,
@@ -381,6 +392,7 @@ def hist_frame(
381392
xrot=None,
382393
ylabelsize=None,
383394
yrot=None,
395+
figure=None,
384396
ax=None,
385397
sharex=False,
386398
sharey=False,
@@ -397,6 +409,7 @@ def hist_frame(
397409
data,
398410
column=column,
399411
by=by,
412+
figure=figure,
400413
ax=ax,
401414
grid=grid,
402415
figsize=figsize,
@@ -430,6 +443,7 @@ def hist_frame(
430443

431444
fig, axes = create_subplots(
432445
naxes=naxes,
446+
figure=figure,
433447
ax=ax,
434448
squeeze=False,
435449
sharex=sharex,

pandas/plotting/_matplotlib/tools.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# being a bit too dynamic
22
from math import ceil
3-
from typing import TYPE_CHECKING, Iterable, List, Sequence, Tuple, Union
3+
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple, Union
44
import warnings
55

66
import matplotlib.table
@@ -17,6 +17,7 @@
1717
if TYPE_CHECKING:
1818
from matplotlib.axes import Axes
1919
from matplotlib.axis import Axis
20+
from matplotlib.figure import Figure
2021
from matplotlib.lines import Line2D
2122
from matplotlib.table import Table
2223

@@ -106,6 +107,7 @@ def create_subplots(
106107
sharey: bool = False,
107108
squeeze: bool = True,
108109
subplot_kw=None,
110+
figure: Optional["Figure"] = None,
109111
ax=None,
110112
layout=None,
111113
layout_type: str = "box",
@@ -145,6 +147,9 @@ def create_subplots(
145147
Dict with keywords passed to the add_subplot() call used to create each
146148
subplots.
147149
150+
figure : Matplotlib figure object, optional
151+
Existing figure to be used for plotting.
152+
148153
ax : Matplotlib axis object, optional
149154
150155
layout : tuple
@@ -190,7 +195,9 @@ def create_subplots(
190195
if subplot_kw is None:
191196
subplot_kw = {}
192197

193-
if ax is None:
198+
if figure is not None:
199+
fig = figure
200+
elif ax is None:
194201
fig = plt.figure(**fig_kw)
195202
else:
196203
if is_list_like(ax):

pandas/tests/plotting/test_hist_method.py

+14
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,13 @@ def test_hist_with_legend_raises(self, by):
152152
with pytest.raises(ValueError, match="Cannot use both legend and label"):
153153
s.hist(legend=True, by=by, label="c")
154154

155+
def test_hist_with_figure_argument(self):
156+
# GH37278
157+
index = 15 * ["1"] + 15 * ["2"]
158+
s = Series(np.random.randn(30), index=index, name="a")
159+
_check_plot_works(s.hist, figure=self.plt.figure())
160+
_check_plot_works(s.plot.hist, figure=self.plt.figure())
161+
155162

156163
@td.skip_if_no_mpl
157164
class TestDataFramePlots(TestPlotBase):
@@ -395,6 +402,13 @@ def test_hist_with_legend_raises(self, by, column):
395402
with pytest.raises(ValueError, match="Cannot use both legend and label"):
396403
df.hist(legend=True, by=by, column=column, label="d")
397404

405+
def test_hist_with_figure_argument(self):
406+
# GH37278
407+
index = Index(15 * ["1"] + 15 * ["2"], name="c")
408+
df = DataFrame(np.random.randn(30, 2), index=index, columns=["a", "b"])
409+
_check_plot_works(df.hist, figure=self.plt.figure())
410+
_check_plot_works(df.plot.hist, figure=self.plt.figure())
411+
398412

399413
@td.skip_if_no_mpl
400414
class TestDataFrameGroupByPlots(TestPlotBase):

0 commit comments

Comments
 (0)