Skip to content

Commit 436ca0d

Browse files
committed
Fixes issue #8193
1 parent c4de906 commit 436ca0d

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

pandas/plotting/_matplotlib/core.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,10 @@ def _apply_style_colors(self, colors, kwds, col_num, label):
743743
has_color = "color" in kwds or self.colormap is not None
744744
nocolor_style = style is None or re.match("[a-z]+", style) is None
745745
if (has_color or self.subplots) and nocolor_style:
746-
kwds["color"] = colors[col_num % len(colors)]
746+
if isinstance(colors, dict):
747+
kwds["color"] = colors[label]
748+
else:
749+
kwds["color"] = colors[col_num % len(colors)]
747750
return style, kwds
748751

749752
def _get_colors(self, num_colors=None, color_kwds="color"):
@@ -1356,12 +1359,15 @@ def _make_plot(self):
13561359

13571360
pos_prior = neg_prior = np.zeros(len(self.data))
13581361
K = self.nseries
1359-
13601362
for i, (label, y) in enumerate(self._iter_data(fillna=0)):
13611363
ax = self._get_ax(i)
13621364
kwds = self.kwds.copy()
13631365
if self._is_series:
13641366
kwds["color"] = colors
1367+
1368+
elif isinstance(colors, dict):
1369+
kwds["color"] = colors[label]
1370+
13651371
else:
13661372
kwds["color"] = colors[i % ncolors]
13671373

pandas/plotting/_matplotlib/style.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ def _get_standard_colors(
2727
warnings.warn(
2828
"'color' and 'colormap' cannot be used " "simultaneously. Using 'color'"
2929
)
30-
colors = list(color) if is_list_like(color) else color
30+
colors = (
31+
list(color)
32+
if is_list_like(color) and not isinstance(color, dict)
33+
else color
34+
)
3135
else:
3236
if color_type == "default":
3337
# need to call list() on the result to copy so we don't

pandas/tests/plotting/test_misc.py

+20
Original file line numberDiff line numberDiff line change
@@ -417,3 +417,23 @@ def test_get_standard_colors_no_appending(self):
417417
color_list = cm.gnuplot(np.linspace(0, 1, 16))
418418
p = df.A.plot.bar(figsize=(16, 7), color=color_list)
419419
assert p.patches[1].get_facecolor() == p.patches[17].get_facecolor()
420+
421+
@pytest.mark.slow
422+
def test_dictionary_color(self):
423+
# Test plot color dictionary format
424+
data_files = ["a", "b"]
425+
426+
expected = [(0.5, 0.24, 0.6), (0.3, 0.7, 0.7)]
427+
428+
df1 = DataFrame(np.random.rand(2, 2), columns=data_files)
429+
dic_color = {"b": (0.3, 0.7, 0.7), "a": (0.5, 0.24, 0.6)}
430+
431+
# Bar color test
432+
ax = df1.plot(kind="bar", color=dic_color)
433+
colors = [rect.get_facecolor()[0:-1] for rect in ax.get_children()[0:3:2]]
434+
assert all(color == expected[index] for index, color in enumerate(colors))
435+
436+
# Line color test
437+
ax = df1.plot(kind="line", color=dic_color)
438+
colors = [rect.get_color() for rect in ax.get_lines()[0:2]]
439+
assert all(color == expected[index] for index, color in enumerate(colors))

0 commit comments

Comments
 (0)