Skip to content

CLN: clean color selection in _matplotlib/style #37203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Nov 4, 2020
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
0f1b99a
REF: extract functions
ivanovmg Oct 17, 2020
901453a
CLN: remove try/except/ZeroDivisionError
ivanovmg Oct 17, 2020
201b25f
REF: drop unnecesasry if statement
ivanovmg Oct 17, 2020
8e13df5
CLN: simplify logic
ivanovmg Oct 17, 2020
37a820d
DOC: add short docstrings
ivanovmg Oct 17, 2020
3883a13
CLN: simplify logic further
ivanovmg Oct 17, 2020
f93743c
TYP: add type annotations
ivanovmg Oct 17, 2020
b4c3267
REF: more explicitly handle string color
ivanovmg Oct 17, 2020
6af1543
FIX: fix mpl registry reset
ivanovmg Oct 17, 2020
31125f7
TYP: fix typing in _get_colors_from_color
ivanovmg Oct 18, 2020
45647a4
CLN: eliminate use of legacy "axes.color_cycle"
ivanovmg Oct 18, 2020
393ae46
REF: extract generator function to simplify logic
ivanovmg Oct 18, 2020
fe66213
TST: add tests for get_standard_colors
ivanovmg Oct 18, 2020
1626108
CLN: drop list comprehension for generator expr
ivanovmg Oct 18, 2020
79b0f08
TYP: annotate get_standard_colors
ivanovmg Oct 18, 2020
f513bdb
DEP: add testing dependency (cycler)
ivanovmg Oct 18, 2020
76f7663
Remove test_style temporary
ivanovmg Oct 18, 2020
0f0f4bc
BLD: remove cycler from dependencies temporary
ivanovmg Oct 18, 2020
b8daf79
Revert "Remove test_style temporary"
ivanovmg Oct 19, 2020
37734e8
REF: import cycler from matplotlib.pyplot
ivanovmg Oct 19, 2020
765836f
TST: mark test skip if no mpl
ivanovmg Oct 19, 2020
4479e37
Merge branch 'master' into cleanup/matplotlib-style
ivanovmg Oct 19, 2020
dd9efd7
Merge branch 'master' into cleanup/matplotlib-style
ivanovmg Oct 20, 2020
f0ea701
REF: use pytest.importorskip
ivanovmg Oct 23, 2020
dedd0dd
REF: extract new method _is_single_color
ivanovmg Oct 30, 2020
b369834
DOC: add/update docstrings
ivanovmg Nov 3, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 144 additions & 51 deletions pandas/plotting/_matplotlib/style.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
# being a bit too dynamic
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterator,
List,
Optional,
Sequence,
Union,
cast,
)
import warnings

import matplotlib.cm as cm
Expand All @@ -9,74 +19,156 @@

import pandas.core.common as com

if TYPE_CHECKING:
from matplotlib.colors import Colormap


Color = Union[str, Sequence[float]]


def get_standard_colors(
num_colors: int, colormap=None, color_type: str = "default", color=None
num_colors: int,
colormap: Optional["Colormap"] = None,
color_type: str = "default",
color: Optional[Union[Dict[str, Color], Color, Collection[Color]]] = None,
):
import matplotlib.pyplot as plt

if isinstance(color, dict):
return color

colors = _get_colors(
color=color,
colormap=colormap,
color_type=color_type,
num_colors=num_colors,
)

return _cycle_colors(colors, num_colors=num_colors)


def _get_colors(
*,
color: Optional[Union[Color, Collection[Color]]],
colormap: Optional[Union[str, "Colormap"]],
color_type: str,
num_colors: int,
) -> List[Color]:
"""Get colors from user input."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a doc-string.

also the current summary is not very descriptive, nor is the function name .

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it into _derive_colors. I am not sure if you would consider it a better alternative. I thought of having a more verbose name (like _derive_colors_from_cmap_color_and_type, but that does not seem reasonable to me, although I added the explanation in the docstring.
Anyway derive seems a better verb, implying some logic underneath.

if color is None and colormap is not None:
if isinstance(colormap, str):
cmap = colormap
colormap = cm.get_cmap(colormap)
if colormap is None:
raise ValueError(f"Colormap {cmap} is not recognized")
colors = [colormap(num) for num in np.linspace(0, 1, num=num_colors)]
return _get_colors_from_colormap(colormap, num_colors=num_colors)
elif color is not None:
if colormap is not None:
warnings.warn(
"'color' and 'colormap' cannot be used simultaneously. Using 'color'"
)
colors = (
list(color)
if is_list_like(color) and not isinstance(color, dict)
else color
)
return _get_colors_from_color(color)
else:
if color_type == "default":
# need to call list() on the result to copy so we don't
# modify the global rcParams below
try:
colors = [c["color"] for c in list(plt.rcParams["axes.prop_cycle"])]
except KeyError:
colors = list(plt.rcParams.get("axes.color_cycle", list("bgrcmyk")))
if isinstance(colors, str):
colors = list(colors)

colors = colors[0:num_colors]
elif color_type == "random":

def random_color(column):
""" Returns a random color represented as a list of length 3"""
# GH17525 use common._random_state to avoid resetting the seed
rs = com.random_state(column)
return rs.rand(3).tolist()

colors = [random_color(num) for num in range(num_colors)]
else:
raise ValueError("color_type must be either 'default' or 'random'")
return _get_colors_from_color_type(color_type, num_colors=num_colors)

if isinstance(colors, str) and _is_single_color(colors):
# GH #36972
colors = [colors]

# Append more colors by cycling if there is not enough color.
# Extra colors will be ignored by matplotlib if there are more colors
# than needed and nothing needs to be done here.
def _cycle_colors(colors: List[Color], num_colors: int) -> List[Color]:
"""Append more colors by cycling if there is not enough color.

Extra colors will be ignored by matplotlib if there are more colors
than needed and nothing needs to be done here.
"""
if len(colors) < num_colors:
try:
multiple = num_colors // len(colors) - 1
except ZeroDivisionError:
raise ValueError("Invalid color argument: ''")
multiple = num_colors // len(colors) - 1
mod = num_colors % len(colors)

colors += multiple * colors
colors += colors[:mod]

return colors


def _is_single_color(color: str) -> bool:
def _get_colors_from_colormap(
colormap: Union[str, "Colormap"],
num_colors: int,
) -> List[Color]:
"""Get colors from colormap."""
colormap = _get_cmap_instance(colormap)
return [colormap(num) for num in np.linspace(0, 1, num=num_colors)]


def _get_cmap_instance(colormap: Union[str, "Colormap"]) -> "Colormap":
"""Get instance of matplotlib colormap."""
if isinstance(colormap, str):
cmap = colormap
colormap = cm.get_cmap(colormap)
if colormap is None:
raise ValueError(f"Colormap {cmap} is not recognized")
return colormap


def _get_colors_from_color(
color: Union[Color, Collection[Color]],
) -> List[Color]:
"""Get colors from user input color."""
if len(color) == 0:
raise ValueError(f"Invalid color argument: {color}")

if isinstance(color, str) and _is_single_color(color):
# GH #36972
return [color]

if _is_floats_color(color):
color = cast(Sequence[float], color)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This cast seems unnecessary - is there an error this solves?

Copy link
Member Author

@ivanovmg ivanovmg Oct 29, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, cast is necessary.
At this point color is of type Union[Color, Collection[Color]]. And for some reason _is_floats_color check does not filter out Collection[Color]. So, instead of ignoring, I added cast here.

mypy error if removing cast.

pandas\plotting\_matplotlib\style.py:114: error: List item 0 has incompatible type "Union[Union[str, Sequence[float]], Collection[Union[str, Sequence[float]]]]"; expected "Union[str, Sequence[float]]"  [list-item]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Functions can't narrow types yet in mypy, but regardless this is pretty confusing. Can you try to refactor this to make things clearer?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not know how to make it even clearer.
The logic is the following.

  1. Start with color being maybe
    a. a single string color
    b. or multiple string color (like 'rgbk')
    c. or single float color (0.1, 0.2, 0.3)
    d. or multiple float colors
  2. Address option a.
  3. Address option b.
  4. Address options c and d.

Is the logic confusing or the two cast statements are?

Copy link
Member Author

@ivanovmg ivanovmg Oct 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WillAyd, I updated the logic a little bit.

  1. Address one color (either string or sequence of floats)
  2. Address collection of colors

Casting is still there as there are no other options to make mypy happy.

Please let me know what you think about this implementation.

return [color]

color = cast(Collection[Color], color)
return list(_gen_list_of_colors_from_iterable(color))


def _gen_list_of_colors_from_iterable(color: Collection[Color]) -> Iterator[Color]:
"""
Yield colors from string of several letters or from collection of colors.
"""
for x in color:
if _is_single_color(x):
yield x
else:
raise ValueError(f"Invalid color {x}")


def _is_floats_color(color: Union[Color, Collection[Color]]) -> bool:
"""Check if color comprises a sequence of floats representing color."""
return bool(
is_list_like(color)
and (len(color) == 3 or len(color) == 4)
and all(isinstance(x, float) for x in color)
)


def _get_colors_from_color_type(color_type: str, num_colors: int) -> List[Color]:
"""Get colors from user input color type."""
if color_type == "default":
return _get_default_colors(num_colors)
elif color_type == "random":
return _get_random_colors(num_colors)
else:
raise ValueError("color_type must be either 'default' or 'random'")


def _get_default_colors(num_colors: int) -> List[Color]:
"""Get ``num_colors`` of default colors from matplotlib rc params."""
import matplotlib.pyplot as plt

colors = [c["color"] for c in plt.rcParams["axes.prop_cycle"]]
return colors[0:num_colors]


def _get_random_colors(num_colors: int) -> List[Color]:
"""Get ``num_colors`` of random colors."""
return [_random_color(num) for num in range(num_colors)]


def _random_color(column: int) -> List[float]:
"""Get a random color represented as a list of length 3"""
# GH17525 use common._random_state to avoid resetting the seed
rs = com.random_state(column)
return rs.rand(3).tolist()


def _is_single_color(color: Color) -> bool:
"""Check if ``color`` is a single color.

Examples of single colors:
Expand All @@ -85,11 +177,12 @@ def _is_single_color(color: str) -> bool:
- 'red'
- 'green'
- 'C3'
- 'firebrick'

Parameters
----------
color : string
Color string.
color : Color
Color string or sequence of floats.

Returns
-------
Expand Down
162 changes: 162 additions & 0 deletions pandas/tests/plotting/test_style.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from contextlib import suppress

import pytest

import pandas.util._test_decorators as td

from pandas import Series

with suppress(ImportError):
from pandas.plotting._matplotlib.style import get_standard_colors


@td.skip_if_no_mpl
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could just use pytest.importorskip see pandas\tests\plotting\test_converter.py or other test modules testing optional dependencies. e.g. pytables

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did that. Please check if this is what you suggested.

class TestGetStandardColors:
@pytest.mark.parametrize(
"num_colors, expected",
[
(3, ["red", "green", "blue"]),
(5, ["red", "green", "blue", "red", "green"]),
(7, ["red", "green", "blue", "red", "green", "blue", "red"]),
(2, ["red", "green"]),
(1, ["red"]),
],
)
def test_default_colors_named_from_prop_cycle(self, num_colors, expected):
import matplotlib as mpl
from matplotlib.pyplot import cycler

mpl_params = {
"axes.prop_cycle": cycler(color=["red", "green", "blue"]),
}
with mpl.rc_context(rc=mpl_params):
result = get_standard_colors(num_colors=num_colors)
assert result == expected

@pytest.mark.parametrize(
"num_colors, expected",
[
(1, ["b"]),
(3, ["b", "g", "r"]),
(4, ["b", "g", "r", "y"]),
(5, ["b", "g", "r", "y", "b"]),
(7, ["b", "g", "r", "y", "b", "g", "r"]),
],
)
def test_default_colors_named_from_prop_cycle_string(self, num_colors, expected):
Copy link
Member

@charlesdong1991 charlesdong1991 Oct 26, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think you could combine this one with the test above, and add cycler_color to parameterize since those two tests look almost identical and test basically the same except how colors are defined

import matplotlib as mpl
from matplotlib.pyplot import cycler

mpl_params = {
"axes.prop_cycle": cycler(color="bgry"),
}
with mpl.rc_context(rc=mpl_params):
result = get_standard_colors(num_colors=num_colors)
assert result == expected

@pytest.mark.parametrize(
"num_colors, expected_name",
[
(1, ["C0"]),
(3, ["C0", "C1", "C2"]),
(
12,
[
"C0",
"C1",
"C2",
"C3",
"C4",
"C5",
"C6",
"C7",
"C8",
"C9",
"C0",
"C1",
],
),
],
)
def test_default_colors_named_undefined_prop_cycle(self, num_colors, expected_name):
import matplotlib as mpl
import matplotlib.colors as mcolors

with mpl.rc_context(rc={}):
expected = [mcolors.to_hex(x) for x in expected_name]
result = get_standard_colors(num_colors=num_colors)
assert result == expected

@pytest.mark.parametrize(
"num_colors, expected",
[
(1, ["red", "green", (0.1, 0.2, 0.3)]),
(2, ["red", "green", (0.1, 0.2, 0.3)]),
(3, ["red", "green", (0.1, 0.2, 0.3)]),
(4, ["red", "green", (0.1, 0.2, 0.3), "red"]),
],
)
def test_user_input_color_sequence(self, num_colors, expected):
color = ["red", "green", (0.1, 0.2, 0.3)]
result = get_standard_colors(color=color, num_colors=num_colors)
assert result == expected

@pytest.mark.parametrize(
"num_colors, expected",
[
(1, ["r", "g", "b", "k"]),
(2, ["r", "g", "b", "k"]),
(3, ["r", "g", "b", "k"]),
(4, ["r", "g", "b", "k"]),
(5, ["r", "g", "b", "k", "r"]),
(6, ["r", "g", "b", "k", "r", "g"]),
Comment on lines +102 to +107
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it's nicer to just test the case where num_color </>/= color, so 3 parametrizations is enough. no need to change here though ^^

],
)
def test_user_input_color_string(self, num_colors, expected):
color = "rgbk"
result = get_standard_colors(color=color, num_colors=num_colors)
assert result == expected

@pytest.mark.parametrize(
"num_colors, expected",
[
(1, [(0.1, 0.2, 0.3)]),
(2, [(0.1, 0.2, 0.3), (0.1, 0.2, 0.3)]),
(3, [(0.1, 0.2, 0.3), (0.1, 0.2, 0.3), (0.1, 0.2, 0.3)]),
],
)
def test_user_input_color_floats(self, num_colors, expected):
color = (0.1, 0.2, 0.3)
result = get_standard_colors(color=color, num_colors=num_colors)
assert result == expected

@pytest.mark.parametrize(
"color, num_colors, expected",
[
("Crimson", 1, ["Crimson"]),
("DodgerBlue", 2, ["DodgerBlue", "DodgerBlue"]),
("firebrick", 3, ["firebrick", "firebrick", "firebrick"]),
],
)
def test_user_input_named_color_string(self, color, num_colors, expected):
result = get_standard_colors(color=color, num_colors=num_colors)
assert result == expected

@pytest.mark.parametrize("color", ["", [], (), Series([], dtype="object")])
def test_empty_color_raises(self, color):
with pytest.raises(ValueError, match="Invalid color argument"):
get_standard_colors(color=color, num_colors=1)

@pytest.mark.parametrize(
"color",
[
"bad_color",
("red", "green", "bad_color"),
(0.1,),
(0.1, 0.2),
(0.1, 0.2, 0.3, 0.4, 0.5), # must be either 3 or 4 floats
],
)
def test_bad_color_raises(self, color):
with pytest.raises(ValueError, match="Invalid color"):
get_standard_colors(color=color, num_colors=5)