Skip to content

Commit 3d86e4b

Browse files
immaxchenproost
authored andcommitted
ENH: Styler.background_gradient to accept vmin vmax and dtype Int64 (pandas-dev#29245)
* ENH: Styler.background_gradient to accept vmin vmax and dtype Int64 Resolve pandas-dev#12145 and pandas-dev#28869 For `vmin` and `vmax` use the same implementation in `Styler.bar` For dtype `Int64` issue, deprecated `.values` and use `.to_numpy` instead Here explicitly assign the dtype to float since we are doing normalize
1 parent fd5d6d1 commit 3d86e4b

File tree

3 files changed

+57
-20
lines changed

3 files changed

+57
-20
lines changed

doc/source/whatsnew/v1.0.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ Other enhancements
112112
- :meth:`read_stata` can read Stata 119 dta files. (:issue:`28250`)
113113
- Added ``encoding`` argument to :meth:`DataFrame.to_string` for non-ascii text (:issue:`28766`)
114114
- Added ``encoding`` argument to :func:`DataFrame.to_html` for non-ascii text (:issue:`28663`)
115+
- :meth:`Styler.background_gradient` now accepts ``vmin`` and ``vmax`` arguments (:issue:`12145`)
115116

116117
Build Changes
117118
^^^^^^^^^^^^^
@@ -391,6 +392,7 @@ I/O
391392
- Bug in :meth:`DataFrame.read_excel` with ``engine='ods'`` when ``sheet_name`` argument references a non-existent sheet (:issue:`27676`)
392393
- Bug in :meth:`pandas.io.formats.style.Styler` formatting for floating values not displaying decimals correctly (:issue:`13257`)
393394
- Bug in :meth:`DataFrame.to_html` when using ``formatters=<list>`` and ``max_cols`` together. (:issue:`25955`)
395+
- Bug in :meth:`Styler.background_gradient` not able to work with dtype ``Int64`` (:issue:`28869`)
394396

395397
Plotting
396398
^^^^^^^^

pandas/io/formats/style.py

+38-20
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import copy
99
from functools import partial
1010
from itertools import product
11+
from typing import Optional
1112
from uuid import uuid1
1213

1314
import numpy as np
@@ -18,7 +19,6 @@
1819
from pandas.util._decorators import Appender
1920

2021
from pandas.core.dtypes.common import is_float, is_string_like
21-
from pandas.core.dtypes.generic import ABCSeries
2222

2323
import pandas as pd
2424
from pandas.api.types import is_dict_like, is_list_like
@@ -963,6 +963,8 @@ def background_gradient(
963963
axis=0,
964964
subset=None,
965965
text_color_threshold=0.408,
966+
vmin: Optional[float] = None,
967+
vmax: Optional[float] = None,
966968
):
967969
"""
968970
Color the background in a gradient style.
@@ -991,6 +993,18 @@ def background_gradient(
991993
992994
.. versionadded:: 0.24.0
993995
996+
vmin : float, optional
997+
Minimum data value that corresponds to colormap minimum value.
998+
When None (default): the minimum value of the data will be used.
999+
1000+
.. versionadded:: 1.0.0
1001+
1002+
vmax : float, optional
1003+
Maximum data value that corresponds to colormap maximum value.
1004+
When None (default): the maximum value of the data will be used.
1005+
1006+
.. versionadded:: 1.0.0
1007+
9941008
Returns
9951009
-------
9961010
self : Styler
@@ -1017,11 +1031,21 @@ def background_gradient(
10171031
low=low,
10181032
high=high,
10191033
text_color_threshold=text_color_threshold,
1034+
vmin=vmin,
1035+
vmax=vmax,
10201036
)
10211037
return self
10221038

10231039
@staticmethod
1024-
def _background_gradient(s, cmap="PuBu", low=0, high=0, text_color_threshold=0.408):
1040+
def _background_gradient(
1041+
s,
1042+
cmap="PuBu",
1043+
low=0,
1044+
high=0,
1045+
text_color_threshold=0.408,
1046+
vmin: Optional[float] = None,
1047+
vmax: Optional[float] = None,
1048+
):
10251049
"""
10261050
Color background in a range according to the data.
10271051
"""
@@ -1033,14 +1057,14 @@ def _background_gradient(s, cmap="PuBu", low=0, high=0, text_color_threshold=0.4
10331057
raise ValueError(msg)
10341058

10351059
with _mpl(Styler.background_gradient) as (plt, colors):
1036-
smin = s.values.min()
1037-
smax = s.values.max()
1060+
smin = np.nanmin(s.to_numpy()) if vmin is None else vmin
1061+
smax = np.nanmax(s.to_numpy()) if vmax is None else vmax
10381062
rng = smax - smin
10391063
# extend lower / upper bounds, compresses color range
10401064
norm = colors.Normalize(smin - (rng * low), smax + (rng * high))
10411065
# matplotlib colors.Normalize modifies inplace?
10421066
# https://github.com/matplotlib/matplotlib/issues/5427
1043-
rgbas = plt.cm.get_cmap(cmap)(norm(s.values))
1067+
rgbas = plt.cm.get_cmap(cmap)(norm(s.to_numpy(dtype=float)))
10441068

10451069
def relative_luminance(rgba):
10461070
"""
@@ -1111,12 +1135,8 @@ def _bar(s, align, colors, width=100, vmin=None, vmax=None):
11111135
Draw bar chart in dataframe cells.
11121136
"""
11131137
# Get input value range.
1114-
smin = s.min() if vmin is None else vmin
1115-
if isinstance(smin, ABCSeries):
1116-
smin = smin.min()
1117-
smax = s.max() if vmax is None else vmax
1118-
if isinstance(smax, ABCSeries):
1119-
smax = smax.max()
1138+
smin = np.nanmin(s.to_numpy()) if vmin is None else vmin
1139+
smax = np.nanmax(s.to_numpy()) if vmax is None else vmax
11201140
if align == "mid":
11211141
smin = min(0, smin)
11221142
smax = max(0, smax)
@@ -1125,7 +1145,7 @@ def _bar(s, align, colors, width=100, vmin=None, vmax=None):
11251145
smax = max(abs(smin), abs(smax))
11261146
smin = -smax
11271147
# Transform to percent-range of linear-gradient
1128-
normed = width * (s.values - smin) / (smax - smin + 1e-12)
1148+
normed = width * (s.to_numpy(dtype=float) - smin) / (smax - smin + 1e-12)
11291149
zero = -width * smin / (smax - smin + 1e-12)
11301150

11311151
def css_bar(start, end, color):
@@ -1304,17 +1324,15 @@ def _highlight_extrema(data, color="yellow", max_=True):
13041324
Highlight the min or max in a Series or DataFrame.
13051325
"""
13061326
attr = "background-color: {0}".format(color)
1327+
1328+
if max_:
1329+
extrema = data == np.nanmax(data.to_numpy())
1330+
else:
1331+
extrema = data == np.nanmin(data.to_numpy())
1332+
13071333
if data.ndim == 1: # Series from .apply
1308-
if max_:
1309-
extrema = data == data.max()
1310-
else:
1311-
extrema = data == data.min()
13121334
return [attr if v else "" for v in extrema]
13131335
else: # DataFrame from .tee
1314-
if max_:
1315-
extrema = data == data.max().max()
1316-
else:
1317-
extrema = data == data.min().min()
13181336
return pd.DataFrame(
13191337
np.where(extrema, attr, ""), index=data.index, columns=data.columns
13201338
)

pandas/tests/io/formats/test_style.py

+17
Original file line numberDiff line numberDiff line change
@@ -1648,6 +1648,23 @@ def test_background_gradient_axis(self):
16481648
assert result[(1, 0)] == mid
16491649
assert result[(1, 1)] == high
16501650

1651+
def test_background_gradient_vmin_vmax(self):
1652+
# GH 12145
1653+
df = pd.DataFrame(range(5))
1654+
ctx = df.style.background_gradient(vmin=1, vmax=3)._compute().ctx
1655+
assert ctx[(0, 0)] == ctx[(1, 0)]
1656+
assert ctx[(4, 0)] == ctx[(3, 0)]
1657+
1658+
def test_background_gradient_int64(self):
1659+
# GH 28869
1660+
df1 = pd.Series(range(3)).to_frame()
1661+
df2 = pd.Series(range(3), dtype="Int64").to_frame()
1662+
ctx1 = df1.style.background_gradient()._compute().ctx
1663+
ctx2 = df2.style.background_gradient()._compute().ctx
1664+
assert ctx2[(0, 0)] == ctx1[(0, 0)]
1665+
assert ctx2[(1, 0)] == ctx1[(1, 0)]
1666+
assert ctx2[(2, 0)] == ctx1[(2, 0)]
1667+
16511668

16521669
def test_block_names():
16531670
# catch accidental removal of a block

0 commit comments

Comments
 (0)