Skip to content

Commit e49957a

Browse files
committed
refactor args
1 parent a385f75 commit e49957a

File tree

2 files changed

+54
-34
lines changed

2 files changed

+54
-34
lines changed

pandas/io/formats/style.py

+46-29
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
)
6161

6262
try:
63-
from matplotlib import colors
63+
import matplotlib as mpl
6464
import matplotlib.pyplot as plt
6565

6666
has_mpl = True
@@ -72,7 +72,7 @@
7272
@contextmanager
7373
def _mpl(func: Callable):
7474
if has_mpl:
75-
yield plt, colors
75+
yield plt, mpl
7676
else:
7777
raise ImportError(no_mpl_message.format(func.__name__))
7878

@@ -2608,7 +2608,8 @@ def bar(
26082608
subset: Subset | None = None,
26092609
axis: Axis | None = 0,
26102610
*,
2611-
color="#d65f5f",
2611+
color: str | list | tuple | None = None,
2612+
cmap: Any | None = None,
26122613
width: float = 100,
26132614
height: float = 100,
26142615
align: str | float | int | Callable = "mid",
@@ -2631,12 +2632,16 @@ def bar(
26312632
Apply to each column (``axis=0`` or ``'index'``), to each row
26322633
(``axis=1`` or ``'columns'``), or to the entire DataFrame at once
26332634
with ``axis=None``.
2634-
color : str, 2-tuple/list, matplotlib Colormap
2635+
color : str or 2-tuple/list
26352636
If a str is passed, the color is the same for both
26362637
negative and positive numbers. If 2-tuple/list is used, the
26372638
first element is the color_negative and the second is the
2638-
color_positive (eg: ['#d65f5f', '#5fba7d']). Alternatively, assigns
2639-
colors from a Colormap based on the datavalues.
2639+
color_positive (eg: ['#d65f5f', '#5fba7d']).
2640+
cmap : str, matplotlib.cm.ColorMap
2641+
A string name of a matplotlib Colormap, or a Colormap object. Cannot be
2642+
used together with ``color``.
2643+
2644+
.. versionadded:: 1.4.0
26402645
width : float, default 100
26412646
The percentage of the cell, measured from the left, in which to draw the
26422647
bars, in [0, 100].
@@ -2686,6 +2691,19 @@ def bar(
26862691
`Table Visualization <../../user_guide/style.ipynb>`_ gives
26872692
a number of examples for different settings and color coordination.
26882693
"""
2694+
if color is None and cmap is None:
2695+
color = "#d65f5f"
2696+
elif color is not None and cmap is not None:
2697+
raise ValueError("`color` and `cmap` cannot both be given")
2698+
elif color is not None:
2699+
if (isinstance(color, (list, tuple)) and len(color) > 2) or not isinstance(
2700+
color, (str, list, tuple)
2701+
):
2702+
raise ValueError(
2703+
"`color` must be string or list or tuple of 2 strings,"
2704+
"(eg: color=['#d65f5f', '#5fba7d'])"
2705+
)
2706+
26892707
if not (0 <= width <= 100):
26902708
raise ValueError(f"`width` must be a value in [0, 100], got {width}")
26912709
elif not (0 <= height <= 100):
@@ -2700,6 +2718,7 @@ def bar(
27002718
axis=axis,
27012719
align=align,
27022720
colors=color,
2721+
cmap=cmap,
27032722
width=width / 100,
27042723
height=height / 100,
27052724
vmin=vmin,
@@ -3256,12 +3275,12 @@ def _background_gradient(
32563275
else: # else validate gmap against the underlying data
32573276
gmap = _validate_apply_axis_arg(gmap, "gmap", float, data)
32583277

3259-
with _mpl(Styler.background_gradient) as (plt, colors):
3278+
with _mpl(Styler.background_gradient) as (plt, mpl):
32603279
smin = np.nanmin(gmap) if vmin is None else vmin
32613280
smax = np.nanmax(gmap) if vmax is None else vmax
32623281
rng = smax - smin
32633282
# extend lower / upper bounds, compresses color range
3264-
norm = colors.Normalize(smin - (rng * low), smax + (rng * high))
3283+
norm = mpl.colors.Normalize(smin - (rng * low), smax + (rng * high))
32653284
rgbas = plt.cm.get_cmap(cmap)(norm(gmap))
32663285

32673286
def relative_luminance(rgba) -> float:
@@ -3290,9 +3309,11 @@ def css(rgba, text_only) -> str:
32903309
if not text_only:
32913310
dark = relative_luminance(rgba) < text_color_threshold
32923311
text_color = "#f1f1f1" if dark else "#000000"
3293-
return f"background-color: {colors.rgb2hex(rgba)};color: {text_color};"
3312+
return (
3313+
f"background-color: {mpl.colors.rgb2hex(rgba)};color: {text_color};"
3314+
)
32943315
else:
3295-
return f"color: {colors.rgb2hex(rgba)};"
3316+
return f"color: {mpl.colors.rgb2hex(rgba)};"
32963317

32973318
if data.ndim == 1:
32983319
return [css(rgba, text_only) for rgba in rgbas]
@@ -3365,7 +3386,8 @@ def _highlight_value(data: FrameOrSeries, op: str, props: str) -> np.ndarray:
33653386
def _bar(
33663387
data: FrameOrSeries,
33673388
align: str | float | int | Callable,
3368-
colors: Any,
3389+
colors: str | list | tuple,
3390+
cmap: Any,
33693391
width: float,
33703392
height: float,
33713393
vmin: float | None,
@@ -3427,7 +3449,7 @@ def css_bar(start: float, end: float, color: str) -> str:
34273449
cell_css += f" {color} {end*100:.1f}%, transparent {end*100:.1f}%)"
34283450
return cell_css
34293451

3430-
def css_calc(x, left: float, right: float, align: str, color: list | str):
3452+
def css_calc(x, left: float, right: float, align: str, color: str | list | tuple):
34313453
"""
34323454
Return the correct CSS for bar placement based on calculated values.
34333455
@@ -3458,7 +3480,7 @@ def css_calc(x, left: float, right: float, align: str, color: list | str):
34583480
if pd.isna(x):
34593481
return base_css
34603482

3461-
if isinstance(color, list):
3483+
if isinstance(color, (list, tuple)):
34623484
color = color[0] if x < 0 else color[1]
34633485
assert isinstance(color, str) # mypy redefinition
34643486

@@ -3525,25 +3547,20 @@ def css_calc(x, left: float, right: float, align: str, color: list | str):
35253547
)
35263548

35273549
rgbas = None
3528-
if not isinstance(colors, (list, str)):
3550+
if cmap is not None:
35293551
# use the matplotlib colormap input
3530-
with _mpl(Styler.bar) as (plt, mpl_colors):
3531-
norm = mpl_colors.Normalize(left, right)
3532-
if not isinstance(colors, mpl_colors.Colormap):
3533-
raise ValueError(
3534-
"`colors` must be a matplotlib Colormap if not string "
3535-
"or list of strings."
3536-
)
3537-
rgbas = colors(norm(values))
3552+
with _mpl(Styler.bar) as (plt, mpl):
3553+
cmap = (
3554+
mpl.cm.get_cmap(cmap)
3555+
if isinstance(cmap, str)
3556+
else cmap # assumed to be a Colormap instance as documented
3557+
)
3558+
norm = mpl.colors.Normalize(left, right)
3559+
rgbas = cmap(norm(values))
35383560
if data.ndim == 1:
3539-
rgbas = [mpl_colors.rgb2hex(rgba) for rgba in rgbas]
3561+
rgbas = [mpl.colors.rgb2hex(rgba) for rgba in rgbas]
35403562
else:
3541-
rgbas = [[mpl_colors.rgb2hex(rgba) for rgba in row] for row in rgbas]
3542-
elif isinstance(colors, list) and len(colors) > 2:
3543-
raise ValueError(
3544-
"`color` must be string, list-like of 2 strings, or matplotlib Colormap "
3545-
"(eg: color=['#d65f5f', '#5fba7d'])"
3546-
)
3563+
rgbas = [[mpl.colors.rgb2hex(rgba) for rgba in row] for row in rgbas]
35473564

35483565
assert isinstance(align, str) # mypy: should now be in [left, right, mid, zero]
35493566
if data.ndim == 1:

pandas/tests/io/formats/style/test_matplotlib.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,10 @@ def test_background_gradient_gmap_wrong_series(styler_blank):
260260
styler_blank.background_gradient(gmap=gmap, axis=None)._compute()
261261

262262

263-
def test_bar_colormap():
263+
@pytest.mark.parametrize("cmap", ["PuBu", mpl.cm.get_cmap("PuBu")])
264+
def test_bar_colormap(cmap):
264265
data = DataFrame([[1, 2], [3, 4]])
265-
ctx = data.style.bar(color=mpl.cm.get_cmap("PuBu"), axis=None)._compute().ctx
266+
ctx = data.style.bar(cmap=cmap, axis=None)._compute().ctx
266267
pubu_colors = {
267268
(0, 0): "#d0d1e6",
268269
(1, 0): "#056faf",
@@ -274,10 +275,12 @@ def test_bar_colormap():
274275

275276

276277
def test_bar_color_raises(df):
277-
msg = "`colors` must be a matplotlib Colormap if not string or list of strings"
278+
msg = "`color` must be string or list or tuple of 2 strings"
278279
with pytest.raises(ValueError, match=msg):
279280
df.style.bar(color={"a", "b"}).to_html()
280-
281-
msg = "`color` must be string, list-like of 2 strings, or matplotlib Colormap"
282281
with pytest.raises(ValueError, match=msg):
283282
df.style.bar(color=["a", "b", "c"]).to_html()
283+
284+
msg = "`color` and `cmap` cannot both be given"
285+
with pytest.raises(ValueError, match=msg):
286+
df.style.bar(color="something", cmap="something else").to_html()

0 commit comments

Comments
 (0)