Skip to content

Commit 6b1579b

Browse files
committed
Some refinements
- Added tests - Added legend to _grouped_hist - DataFrame.hist now respects kwarg 'label' if specified.
1 parent cbbff78 commit 6b1579b

File tree

3 files changed

+43
-2
lines changed

3 files changed

+43
-2
lines changed

pandas/plotting/_matplotlib/hist.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def _grouped_hist(
227227
xrot=None,
228228
ylabelsize=None,
229229
yrot=None,
230+
legend=False,
230231
**kwargs,
231232
):
232233
"""
@@ -252,8 +253,19 @@ def _grouped_hist(
252253
collection of Matplotlib Axes
253254
"""
254255

256+
if legend and "label" not in kwargs:
257+
if isinstance(data, ABCDataFrame):
258+
if column is None:
259+
kwargs["label"] = data.columns
260+
else:
261+
kwargs["label"] = column
262+
else:
263+
kwargs["label"] = data.name
264+
255265
def plot_group(group, ax):
256266
ax.hist(group.dropna().values, bins=bins, **kwargs)
267+
if legend:
268+
ax.legend()
257269

258270
if xrot is None:
259271
xrot = rot
@@ -340,6 +352,7 @@ def hist_series(
340352
xrot=xrot,
341353
ylabelsize=ylabelsize,
342354
yrot=yrot,
355+
legend=legend,
343356
**kwds,
344357
)
345358

@@ -383,6 +396,7 @@ def hist_frame(
383396
xrot=xrot,
384397
ylabelsize=ylabelsize,
385398
yrot=yrot,
399+
legend=legend,
386400
**kwds,
387401
)
388402
return axes
@@ -408,16 +422,17 @@ def hist_frame(
408422
)
409423
_axes = _flatten(axes)
410424

425+
can_set_label = "label" not in kwds
426+
411427
for i, col in enumerate(com.try_sort(data.columns)):
412428
ax = _axes[i]
413-
if legend and "label" not in kwds:
429+
if legend and can_set_label:
414430
kwds["label"] = col
415431
ax.hist(data[col].dropna().values, bins=bins, **kwds)
416432
ax.set_title(col)
417433
ax.grid(grid)
418434
if legend:
419435
ax.legend()
420-
kwds.pop("label")
421436

422437
_set_ticks_props(
423438
axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot

pandas/tests/plotting/test_groupby.py

+15
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33

44
import numpy as np
5+
import pytest
56

67
import pandas.util._test_decorators as td
78

@@ -65,3 +66,17 @@ def test_plot_kwargs(self):
6566

6667
res = df.groupby("z").plot.scatter(x="x", y="y")
6768
assert len(res["a"].collections) == 1
69+
70+
71+
@td.skip_if_no_mpl
72+
@pytest.mark.parametrize("column", [None, "B"])
73+
@pytest.mark.parametrize("label", [None, "D"])
74+
def test_hist_with_legend(column, label):
75+
df = DataFrame(np.random.randn(30, 2), columns=["A", "B"])
76+
df["C"] = 15 * ["a"] + 15 * ["b"]
77+
g = df.groupby("C")
78+
g.hist(column=column, label=label, legend=True)
79+
tm.close()
80+
if column != "B":
81+
g["A"].hist(label=label, legend=True)
82+
tm.close()

pandas/tests/plotting/test_hist_method.py

+11
Original file line numberDiff line numberDiff line change
@@ -460,3 +460,14 @@ def test_axis_share_xy(self):
460460

461461
assert ax1._shared_y_axes.joined(ax1, ax2)
462462
assert ax2._shared_y_axes.joined(ax1, ax2)
463+
464+
465+
@pytest.mark.slow
466+
@pytest.mark.parametrize("by", [None, "C"])
467+
@pytest.mark.parametrize("column", [None, "B"])
468+
@pytest.mark.parametrize("label", [None, "D"])
469+
def test_hist_with_legend(by, column, label):
470+
df = DataFrame(np.random.randn(30, 2), columns=["A", "B"])
471+
df["C"] = 15 * ["a"] + 15 * ["b"]
472+
df = df.set_index("C")
473+
_check_plot_works(df.hist, by=by, column=column, label=label, legend=True)

0 commit comments

Comments
 (0)