Skip to content

Commit 49c20a4

Browse files
committed
Fixes issue pandas-dev#8193
1 parent ee59b7d commit 49c20a4

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

pandas/plotting/_matplotlib/core.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,10 @@ def _apply_style_colors(self, colors, kwds, col_num, label):
743743
has_color = "color" in kwds or self.colormap is not None
744744
nocolor_style = style is None or re.match("[a-z]+", style) is None
745745
if (has_color or self.subplots) and nocolor_style:
746-
kwds["color"] = colors[col_num % len(colors)]
746+
if isinstance(colors, dict):
747+
kwds["color"] = colors[label]
748+
else:
749+
kwds["color"] = colors[col_num % len(colors)]
747750
return style, kwds
748751

749752
def _get_colors(self, num_colors=None, color_kwds="color"):
@@ -1356,12 +1359,13 @@ def _make_plot(self):
13561359

13571360
pos_prior = neg_prior = np.zeros(len(self.data))
13581361
K = self.nseries
1359-
13601362
for i, (label, y) in enumerate(self._iter_data(fillna=0)):
13611363
ax = self._get_ax(i)
13621364
kwds = self.kwds.copy()
13631365
if self._is_series:
13641366
kwds["color"] = colors
1367+
elif isinstance(colors, dict):
1368+
kwds["color"] = colors[label]
13651369
else:
13661370
kwds["color"] = colors[i % ncolors]
13671371

pandas/plotting/_matplotlib/style.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ def _get_standard_colors(
2727
warnings.warn(
2828
"'color' and 'colormap' cannot be used simultaneously. Using 'color'"
2929
)
30-
colors = list(color) if is_list_like(color) else color
30+
colors = (
31+
list(color)
32+
if is_list_like(color) and not isinstance(color, dict)
33+
else color
34+
)
3135
else:
3236
if color_type == "default":
3337
# need to call list() on the result to copy so we don't

pandas/tests/plotting/test_misc.py

+21
Original file line numberDiff line numberDiff line change
@@ -417,3 +417,24 @@ def test_get_standard_colors_no_appending(self):
417417
color_list = cm.gnuplot(np.linspace(0, 1, 16))
418418
p = df.A.plot.bar(figsize=(16, 7), color=color_list)
419419
assert p.patches[1].get_facecolor() == p.patches[17].get_facecolor()
420+
421+
@pytest.mark.slow
422+
def test_dictionary_color(self):
423+
# issue-8193
424+
# Test plot color dictionary format
425+
data_files = ["a", "b"]
426+
427+
expected = [(0.5, 0.24, 0.6), (0.3, 0.7, 0.7)]
428+
429+
df1 = DataFrame(np.random.rand(2, 2), columns=data_files)
430+
dic_color = {"b": (0.3, 0.7, 0.7), "a": (0.5, 0.24, 0.6)}
431+
432+
# Bar color test
433+
ax = df1.plot(kind="bar", color=dic_color)
434+
colors = [rect.get_facecolor()[0:-1] for rect in ax.get_children()[0:3:2]]
435+
assert all(color == expected[index] for index, color in enumerate(colors))
436+
437+
# Line color test
438+
ax = df1.plot(kind="line", color=dic_color)
439+
colors = [rect.get_color() for rect in ax.get_lines()[0:2]]
440+
assert all(color == expected[index] for index, color in enumerate(colors))

0 commit comments

Comments
 (0)