Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 36f026d

Browse files
authoredNov 4, 2020
CLN: clean color selection in _matplotlib/style (#37203)
1 parent e5cbaec commit 36f026d

File tree

2 files changed

+384
-53
lines changed

2 files changed

+384
-53
lines changed
 

‎pandas/plotting/_matplotlib/style.py

Lines changed: 227 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
# being a bit too dynamic
1+
from typing import (
2+
TYPE_CHECKING,
3+
Collection,
4+
Dict,
5+
Iterator,
6+
List,
7+
Optional,
8+
Sequence,
9+
Union,
10+
cast,
11+
)
212
import warnings
313

414
import matplotlib.cm as cm
@@ -9,92 +19,256 @@
919

1020
import pandas.core.common as com
1121

22+
if TYPE_CHECKING:
23+
from matplotlib.colors import Colormap
24+
25+
26+
Color = Union[str, Sequence[float]]
27+
1228

1329
def get_standard_colors(
14-
num_colors: int, colormap=None, color_type: str = "default", color=None
30+
num_colors: int,
31+
colormap: Optional["Colormap"] = None,
32+
color_type: str = "default",
33+
color: Optional[Union[Dict[str, Color], Color, Collection[Color]]] = None,
1534
):
16-
import matplotlib.pyplot as plt
35+
"""
36+
Get standard colors based on `colormap`, `color_type` or `color` inputs.
37+
38+
Parameters
39+
----------
40+
num_colors : int
41+
Minimum number of colors to be returned.
42+
Ignored if `color` is a dictionary.
43+
colormap : :py:class:`matplotlib.colors.Colormap`, optional
44+
Matplotlib colormap.
45+
When provided, the resulting colors will be derived from the colormap.
46+
color_type : {"default", "random"}, optional
47+
Type of colors to derive. Used if provided `color` and `colormap` are None.
48+
Ignored if either `color` or `colormap` are not None.
49+
color : dict or str or sequence, optional
50+
Color(s) to be used for deriving sequence of colors.
51+
Can be either be a dictionary, or a single color (single color string,
52+
or sequence of floats representing a single color),
53+
or a sequence of colors.
54+
55+
Returns
56+
-------
57+
dict or list
58+
Standard colors. Can either be a mapping if `color` was a dictionary,
59+
or a list of colors with a length of `num_colors` or more.
60+
61+
Warns
62+
-----
63+
UserWarning
64+
If both `colormap` and `color` are provided.
65+
Parameter `color` will override.
66+
"""
67+
if isinstance(color, dict):
68+
return color
69+
70+
colors = _derive_colors(
71+
color=color,
72+
colormap=colormap,
73+
color_type=color_type,
74+
num_colors=num_colors,
75+
)
76+
77+
return _cycle_colors(colors, num_colors=num_colors)
78+
79+
80+
def _derive_colors(
81+
*,
82+
color: Optional[Union[Color, Collection[Color]]],
83+
colormap: Optional[Union[str, "Colormap"]],
84+
color_type: str,
85+
num_colors: int,
86+
) -> List[Color]:
87+
"""
88+
Derive colors from either `colormap`, `color_type` or `color` inputs.
89+
90+
Get a list of colors either from `colormap`, or from `color`,
91+
or from `color_type` (if both `colormap` and `color` are None).
92+
93+
Parameters
94+
----------
95+
color : str or sequence, optional
96+
Color(s) to be used for deriving sequence of colors.
97+
Can be either be a single color (single color string, or sequence of floats
98+
representing a single color), or a sequence of colors.
99+
colormap : :py:class:`matplotlib.colors.Colormap`, optional
100+
Matplotlib colormap.
101+
When provided, the resulting colors will be derived from the colormap.
102+
color_type : {"default", "random"}, optional
103+
Type of colors to derive. Used if provided `color` and `colormap` are None.
104+
Ignored if either `color` or `colormap`` are not None.
105+
num_colors : int
106+
Number of colors to be extracted.
17107
108+
Returns
109+
-------
110+
list
111+
List of colors extracted.
112+
113+
Warns
114+
-----
115+
UserWarning
116+
If both `colormap` and `color` are provided.
117+
Parameter `color` will override.
118+
"""
18119
if color is None and colormap is not None:
19-
if isinstance(colormap, str):
20-
cmap = colormap
21-
colormap = cm.get_cmap(colormap)
22-
if colormap is None:
23-
raise ValueError(f"Colormap {cmap} is not recognized")
24-
colors = [colormap(num) for num in np.linspace(0, 1, num=num_colors)]
120+
return _get_colors_from_colormap(colormap, num_colors=num_colors)
25121
elif color is not None:
26122
if colormap is not None:
27123
warnings.warn(
28124
"'color' and 'colormap' cannot be used simultaneously. Using 'color'"
29125
)
30-
colors = (
31-
list(color)
32-
if is_list_like(color) and not isinstance(color, dict)
33-
else color
34-
)
126+
return _get_colors_from_color(color)
35127
else:
36-
if color_type == "default":
37-
# need to call list() on the result to copy so we don't
38-
# modify the global rcParams below
39-
try:
40-
colors = [c["color"] for c in list(plt.rcParams["axes.prop_cycle"])]
41-
except KeyError:
42-
colors = list(plt.rcParams.get("axes.color_cycle", list("bgrcmyk")))
43-
if isinstance(colors, str):
44-
colors = list(colors)
45-
46-
colors = colors[0:num_colors]
47-
elif color_type == "random":
48-
49-
def random_color(column):
50-
""" Returns a random color represented as a list of length 3"""
51-
# GH17525 use common._random_state to avoid resetting the seed
52-
rs = com.random_state(column)
53-
return rs.rand(3).tolist()
54-
55-
colors = [random_color(num) for num in range(num_colors)]
56-
else:
57-
raise ValueError("color_type must be either 'default' or 'random'")
128+
return _get_colors_from_color_type(color_type, num_colors=num_colors)
58129

59-
if isinstance(colors, str) and _is_single_color(colors):
60-
# GH #36972
61-
colors = [colors]
62130

63-
# Append more colors by cycling if there is not enough color.
64-
# Extra colors will be ignored by matplotlib if there are more colors
65-
# than needed and nothing needs to be done here.
131+
def _cycle_colors(colors: List[Color], num_colors: int) -> List[Color]:
132+
"""Append more colors by cycling if there is not enough color.
133+
134+
Extra colors will be ignored by matplotlib if there are more colors
135+
than needed and nothing needs to be done here.
136+
"""
66137
if len(colors) < num_colors:
67-
try:
68-
multiple = num_colors // len(colors) - 1
69-
except ZeroDivisionError:
70-
raise ValueError("Invalid color argument: ''")
138+
multiple = num_colors // len(colors) - 1
71139
mod = num_colors % len(colors)
72-
73140
colors += multiple * colors
74141
colors += colors[:mod]
75142

76143
return colors
77144

78145

79-
def _is_single_color(color: str) -> bool:
80-
"""Check if ``color`` is a single color.
146+
def _get_colors_from_colormap(
147+
colormap: Union[str, "Colormap"],
148+
num_colors: int,
149+
) -> List[Color]:
150+
"""Get colors from colormap."""
151+
colormap = _get_cmap_instance(colormap)
152+
return [colormap(num) for num in np.linspace(0, 1, num=num_colors)]
153+
154+
155+
def _get_cmap_instance(colormap: Union[str, "Colormap"]) -> "Colormap":
156+
"""Get instance of matplotlib colormap."""
157+
if isinstance(colormap, str):
158+
cmap = colormap
159+
colormap = cm.get_cmap(colormap)
160+
if colormap is None:
161+
raise ValueError(f"Colormap {cmap} is not recognized")
162+
return colormap
163+
164+
165+
def _get_colors_from_color(
166+
color: Union[Color, Collection[Color]],
167+
) -> List[Color]:
168+
"""Get colors from user input color."""
169+
if len(color) == 0:
170+
raise ValueError(f"Invalid color argument: {color}")
171+
172+
if _is_single_color(color):
173+
color = cast(Color, color)
174+
return [color]
175+
176+
color = cast(Collection[Color], color)
177+
return list(_gen_list_of_colors_from_iterable(color))
178+
179+
180+
def _is_single_color(color: Union[Color, Collection[Color]]) -> bool:
181+
"""Check if `color` is a single color, not a sequence of colors.
182+
183+
Single color is of these kinds:
184+
- Named color "red", "C0", "firebrick"
185+
- Alias "g"
186+
- Sequence of floats, such as (0.1, 0.2, 0.3) or (0.1, 0.2, 0.3, 0.4).
187+
188+
See Also
189+
--------
190+
_is_single_string_color
191+
"""
192+
if isinstance(color, str) and _is_single_string_color(color):
193+
# GH #36972
194+
return True
195+
196+
if _is_floats_color(color):
197+
return True
198+
199+
return False
200+
201+
202+
def _gen_list_of_colors_from_iterable(color: Collection[Color]) -> Iterator[Color]:
203+
"""
204+
Yield colors from string of several letters or from collection of colors.
205+
"""
206+
for x in color:
207+
if _is_single_color(x):
208+
yield x
209+
else:
210+
raise ValueError(f"Invalid color {x}")
211+
212+
213+
def _is_floats_color(color: Union[Color, Collection[Color]]) -> bool:
214+
"""Check if color comprises a sequence of floats representing color."""
215+
return bool(
216+
is_list_like(color)
217+
and (len(color) == 3 or len(color) == 4)
218+
and all(isinstance(x, (int, float)) for x in color)
219+
)
220+
221+
222+
def _get_colors_from_color_type(color_type: str, num_colors: int) -> List[Color]:
223+
"""Get colors from user input color type."""
224+
if color_type == "default":
225+
return _get_default_colors(num_colors)
226+
elif color_type == "random":
227+
return _get_random_colors(num_colors)
228+
else:
229+
raise ValueError("color_type must be either 'default' or 'random'")
230+
231+
232+
def _get_default_colors(num_colors: int) -> List[Color]:
233+
"""Get `num_colors` of default colors from matplotlib rc params."""
234+
import matplotlib.pyplot as plt
235+
236+
colors = [c["color"] for c in plt.rcParams["axes.prop_cycle"]]
237+
return colors[0:num_colors]
238+
239+
240+
def _get_random_colors(num_colors: int) -> List[Color]:
241+
"""Get `num_colors` of random colors."""
242+
return [_random_color(num) for num in range(num_colors)]
243+
244+
245+
def _random_color(column: int) -> List[float]:
246+
"""Get a random color represented as a list of length 3"""
247+
# GH17525 use common._random_state to avoid resetting the seed
248+
rs = com.random_state(column)
249+
return rs.rand(3).tolist()
250+
251+
252+
def _is_single_string_color(color: Color) -> bool:
253+
"""Check if `color` is a single string color.
81254
82-
Examples of single colors:
255+
Examples of single string colors:
83256
- 'r'
84257
- 'g'
85258
- 'red'
86259
- 'green'
87260
- 'C3'
261+
- 'firebrick'
88262
89263
Parameters
90264
----------
91-
color : string
92-
Color string.
265+
color : Color
266+
Color string or sequence of floats.
93267
94268
Returns
95269
-------
96270
bool
97-
True if ``color`` looks like a valid color.
271+
True if `color` looks like a valid color.
98272
False otherwise.
99273
"""
100274
conv = matplotlib.colors.ColorConverter()

‎pandas/tests/plotting/test_style.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import pytest
2+
3+
from pandas import Series
4+
5+
pytest.importorskip("matplotlib")
6+
from pandas.plotting._matplotlib.style import get_standard_colors
7+
8+
9+
class TestGetStandardColors:
10+
@pytest.mark.parametrize(
11+
"num_colors, expected",
12+
[
13+
(3, ["red", "green", "blue"]),
14+
(5, ["red", "green", "blue", "red", "green"]),
15+
(7, ["red", "green", "blue", "red", "green", "blue", "red"]),
16+
(2, ["red", "green"]),
17+
(1, ["red"]),
18+
],
19+
)
20+
def test_default_colors_named_from_prop_cycle(self, num_colors, expected):
21+
import matplotlib as mpl
22+
from matplotlib.pyplot import cycler
23+
24+
mpl_params = {
25+
"axes.prop_cycle": cycler(color=["red", "green", "blue"]),
26+
}
27+
with mpl.rc_context(rc=mpl_params):
28+
result = get_standard_colors(num_colors=num_colors)
29+
assert result == expected
30+
31+
@pytest.mark.parametrize(
32+
"num_colors, expected",
33+
[
34+
(1, ["b"]),
35+
(3, ["b", "g", "r"]),
36+
(4, ["b", "g", "r", "y"]),
37+
(5, ["b", "g", "r", "y", "b"]),
38+
(7, ["b", "g", "r", "y", "b", "g", "r"]),
39+
],
40+
)
41+
def test_default_colors_named_from_prop_cycle_string(self, num_colors, expected):
42+
import matplotlib as mpl
43+
from matplotlib.pyplot import cycler
44+
45+
mpl_params = {
46+
"axes.prop_cycle": cycler(color="bgry"),
47+
}
48+
with mpl.rc_context(rc=mpl_params):
49+
result = get_standard_colors(num_colors=num_colors)
50+
assert result == expected
51+
52+
@pytest.mark.parametrize(
53+
"num_colors, expected_name",
54+
[
55+
(1, ["C0"]),
56+
(3, ["C0", "C1", "C2"]),
57+
(
58+
12,
59+
[
60+
"C0",
61+
"C1",
62+
"C2",
63+
"C3",
64+
"C4",
65+
"C5",
66+
"C6",
67+
"C7",
68+
"C8",
69+
"C9",
70+
"C0",
71+
"C1",
72+
],
73+
),
74+
],
75+
)
76+
def test_default_colors_named_undefined_prop_cycle(self, num_colors, expected_name):
77+
import matplotlib as mpl
78+
import matplotlib.colors as mcolors
79+
80+
with mpl.rc_context(rc={}):
81+
expected = [mcolors.to_hex(x) for x in expected_name]
82+
result = get_standard_colors(num_colors=num_colors)
83+
assert result == expected
84+
85+
@pytest.mark.parametrize(
86+
"num_colors, expected",
87+
[
88+
(1, ["red", "green", (0.1, 0.2, 0.3)]),
89+
(2, ["red", "green", (0.1, 0.2, 0.3)]),
90+
(3, ["red", "green", (0.1, 0.2, 0.3)]),
91+
(4, ["red", "green", (0.1, 0.2, 0.3), "red"]),
92+
],
93+
)
94+
def test_user_input_color_sequence(self, num_colors, expected):
95+
color = ["red", "green", (0.1, 0.2, 0.3)]
96+
result = get_standard_colors(color=color, num_colors=num_colors)
97+
assert result == expected
98+
99+
@pytest.mark.parametrize(
100+
"num_colors, expected",
101+
[
102+
(1, ["r", "g", "b", "k"]),
103+
(2, ["r", "g", "b", "k"]),
104+
(3, ["r", "g", "b", "k"]),
105+
(4, ["r", "g", "b", "k"]),
106+
(5, ["r", "g", "b", "k", "r"]),
107+
(6, ["r", "g", "b", "k", "r", "g"]),
108+
],
109+
)
110+
def test_user_input_color_string(self, num_colors, expected):
111+
color = "rgbk"
112+
result = get_standard_colors(color=color, num_colors=num_colors)
113+
assert result == expected
114+
115+
@pytest.mark.parametrize(
116+
"num_colors, expected",
117+
[
118+
(1, [(0.1, 0.2, 0.3)]),
119+
(2, [(0.1, 0.2, 0.3), (0.1, 0.2, 0.3)]),
120+
(3, [(0.1, 0.2, 0.3), (0.1, 0.2, 0.3), (0.1, 0.2, 0.3)]),
121+
],
122+
)
123+
def test_user_input_color_floats(self, num_colors, expected):
124+
color = (0.1, 0.2, 0.3)
125+
result = get_standard_colors(color=color, num_colors=num_colors)
126+
assert result == expected
127+
128+
@pytest.mark.parametrize(
129+
"color, num_colors, expected",
130+
[
131+
("Crimson", 1, ["Crimson"]),
132+
("DodgerBlue", 2, ["DodgerBlue", "DodgerBlue"]),
133+
("firebrick", 3, ["firebrick", "firebrick", "firebrick"]),
134+
],
135+
)
136+
def test_user_input_named_color_string(self, color, num_colors, expected):
137+
result = get_standard_colors(color=color, num_colors=num_colors)
138+
assert result == expected
139+
140+
@pytest.mark.parametrize("color", ["", [], (), Series([], dtype="object")])
141+
def test_empty_color_raises(self, color):
142+
with pytest.raises(ValueError, match="Invalid color argument"):
143+
get_standard_colors(color=color, num_colors=1)
144+
145+
@pytest.mark.parametrize(
146+
"color",
147+
[
148+
"bad_color",
149+
("red", "green", "bad_color"),
150+
(0.1,),
151+
(0.1, 0.2),
152+
(0.1, 0.2, 0.3, 0.4, 0.5), # must be either 3 or 4 floats
153+
],
154+
)
155+
def test_bad_color_raises(self, color):
156+
with pytest.raises(ValueError, match="Invalid color"):
157+
get_standard_colors(color=color, num_colors=5)

0 commit comments

Comments
 (0)
Please sign in to comment.