Skip to content

Commit c71c4a5

Browse files
JustinZhengBC authored and proost committed
BUG-26214 fix colors parameter in DataFrame.boxplot (pandas-dev#26456)
1 parent 1c1eb80 commit c71c4a5

File tree

3 files changed

+60
-5
lines changed

3 files changed

+60
-5
lines changed

doc/source/whatsnew/v1.0.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -231,6 +231,7 @@ Plotting
231231
- Bug in :meth:`DataFrame.plot` producing incorrect legend markers when plotting multiple series on the same axis (:issue:`18222`)
232232
- Bug in :meth:`DataFrame.plot` when ``kind='box'`` and data contains datetime or timedelta data. These types are now automatically dropped (:issue:`22799`)
233233
- Bug in :meth:`DataFrame.plot.line` and :meth:`DataFrame.plot.area` producing wrong xlim in x-axis (:issue:`27686`, :issue:`25160`, :issue:`24784`)
234+
- Bug where :meth:`DataFrame.boxplot` would not accept a ``color`` parameter like :meth:`DataFrame.plot.box` (:issue:`26214`)
234235
- :func:`set_option` now validates that the plot backend provided to ``'plotting.backend'`` implements the backend when the option is set, rather than when a plot is created (:issue:`28163`)
235236

236237
Groupby/resample/rolling

pandas/plotting/_matplotlib/boxplot.py

+31-5
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from matplotlib.artist import setp
55
import numpy as np
66

7+
from pandas.core.dtypes.common import is_dict_like
78
from pandas.core.dtypes.generic import ABCSeries
89
from pandas.core.dtypes.missing import remove_na_arraylike
910

@@ -250,13 +251,38 @@ def boxplot(
250251
def _get_colors():
251252
# num_colors=3 is required as method maybe_color_bp takes the colors
252253
# in positions 0 and 2.
253-
return _get_standard_colors(color=kwds.get("color"), num_colors=3)
254+
# if colors not provided, use same defaults as DataFrame.plot.box
255+
result = _get_standard_colors(num_colors=3)
256+
result = np.take(result, [0, 0, 2])
257+
result = np.append(result, "k")
258+
259+
colors = kwds.pop("color", None)
260+
if colors:
261+
if is_dict_like(colors):
262+
# replace colors in result array with user-specified colors
263+
# taken from the colors dict parameter
264+
# "boxes" value placed in position 0, "whiskers" in 1, etc.
265+
valid_keys = ["boxes", "whiskers", "medians", "caps"]
266+
key_to_index = dict(zip(valid_keys, range(4)))
267+
for key, value in colors.items():
268+
if key in valid_keys:
269+
result[key_to_index[key]] = value
270+
else:
271+
raise ValueError(
272+
"color dict contains invalid "
273+
"key '{0}' "
274+
"The key must be either {1}".format(key, valid_keys)
275+
)
276+
else:
277+
result.fill(colors)
278+
279+
return result
254280

255281
def maybe_color_bp(bp):
256-
if "color" not in kwds:
257-
setp(bp["boxes"], color=colors[0], alpha=1)
258-
setp(bp["whiskers"], color=colors[0], alpha=1)
259-
setp(bp["medians"], color=colors[2], alpha=1)
282+
setp(bp["boxes"], color=colors[0], alpha=1)
283+
setp(bp["whiskers"], color=colors[1], alpha=1)
284+
setp(bp["medians"], color=colors[2], alpha=1)
285+
setp(bp["caps"], color=colors[3], alpha=1)
260286

261287
def plot_group(keys, values, ax):
262288
keys = [pprint_thing(x) for x in keys]

pandas/tests/plotting/test_boxplot_method.py

+28
Original file line number | Diff line number | Diff line change
@@ -175,6 +175,34 @@ def test_boxplot_numeric_data(self):
175175
ax = df.plot(kind="box")
176176
assert [x.get_text() for x in ax.get_xticklabels()] == ["b", "c"]
177177

178+
@pytest.mark.parametrize(
179+
"colors_kwd, expected",
180+
[
181+
(
182+
dict(boxes="r", whiskers="b", medians="g", caps="c"),
183+
dict(boxes="r", whiskers="b", medians="g", caps="c"),
184+
),
185+
(dict(boxes="r"), dict(boxes="r")),
186+
("r", dict(boxes="r", whiskers="r", medians="r", caps="r")),
187+
],
188+
)
189+
def test_color_kwd(self, colors_kwd, expected):
190+
# GH: 26214
191+
df = DataFrame(random.rand(10, 2))
192+
result = df.boxplot(color=colors_kwd, return_type="dict")
193+
for k, v in expected.items():
194+
assert result[k][0].get_color() == v
195+
196+
@pytest.mark.parametrize(
197+
"dict_colors, msg",
198+
[(dict(boxes="r", invalid_key="r"), "invalid key 'invalid_key'")],
199+
)
200+
def test_color_kwd_errors(self, dict_colors, msg):
201+
# GH: 26214
202+
df = DataFrame(random.rand(10, 2))
203+
with pytest.raises(ValueError, match=msg):
204+
df.boxplot(color=dict_colors, return_type="dict")
205+
178206

179207
@td.skip_if_no_mpl
180208
class TestDataFrameGroupByPlots(TestPlotBase):

0 commit comments

Comments
 (0)