|
4 | 4 | from matplotlib.artist import setp
|
5 | 5 | import numpy as np
|
6 | 6 |
|
| 7 | +from pandas.core.dtypes.common import is_dict_like |
7 | 8 | from pandas.core.dtypes.generic import ABCSeries
|
8 | 9 | from pandas.core.dtypes.missing import remove_na_arraylike
|
9 | 10 |
|
@@ -250,13 +251,38 @@ def boxplot(
|
250 | 251 | def _get_colors():
|
251 | 252 | # num_colors=3 is required as method maybe_color_bp takes the colors
|
252 | 253 | # in positions 0 and 2.
|
253 |
| - return _get_standard_colors(color=kwds.get("color"), num_colors=3) |
| 254 | + # if colors not provided, use same defaults as DataFrame.plot.box |
| 255 | + result = _get_standard_colors(num_colors=3) |
| 256 | + result = np.take(result, [0, 0, 2]) |
| 257 | + result = np.append(result, "k") |
| 258 | + |
| 259 | + colors = kwds.pop("color", None) |
| 260 | + if colors: |
| 261 | + if is_dict_like(colors): |
| 262 | + # replace colors in result array with user-specified colors |
| 263 | + # taken from the colors dict parameter |
| 264 | + # "boxes" value placed in position 0, "whiskers" in 1, etc. |
| 265 | + valid_keys = ["boxes", "whiskers", "medians", "caps"] |
| 266 | + key_to_index = dict(zip(valid_keys, range(4))) |
| 267 | + for key, value in colors.items(): |
| 268 | + if key in valid_keys: |
| 269 | + result[key_to_index[key]] = value |
| 270 | + else: |
| 271 | + raise ValueError( |
| 272 | + "color dict contains invalid " |
| 273 | + "key '{0}' " |
| 274 | + "The key must be either {1}".format(key, valid_keys) |
| 275 | + ) |
| 276 | + else: |
| 277 | + result.fill(colors) |
| 278 | + |
| 279 | + return result |
254 | 280 |
|
255 | 281 | def maybe_color_bp(bp):
|
256 |
| - if "color" not in kwds: |
257 |
| - setp(bp["boxes"], color=colors[0], alpha=1) |
258 |
| - setp(bp["whiskers"], color=colors[0], alpha=1) |
259 |
| - setp(bp["medians"], color=colors[2], alpha=1) |
| 282 | + setp(bp["boxes"], color=colors[0], alpha=1) |
| 283 | + setp(bp["whiskers"], color=colors[1], alpha=1) |
| 284 | + setp(bp["medians"], color=colors[2], alpha=1) |
| 285 | + setp(bp["caps"], color=colors[3], alpha=1) |
260 | 286 |
|
261 | 287 | def plot_group(keys, values, ax):
|
262 | 288 | keys = [pprint_thing(x) for x in keys]
|
|
0 commit comments