diff --git a/doc/source/whatsnew/v1.0.0.rst b/doc/source/whatsnew/v1.0.0.rst index fa1669b1f3343..32602a1ccd24d 100644 --- a/doc/source/whatsnew/v1.0.0.rst +++ b/doc/source/whatsnew/v1.0.0.rst @@ -112,6 +112,7 @@ Other enhancements - :meth:`read_stata` can read Stata 119 dta files. (:issue:`28250`) - Added ``encoding`` argument to :meth:`DataFrame.to_string` for non-ascii text (:issue:`28766`) - Added ``encoding`` argument to :func:`DataFrame.to_html` for non-ascii text (:issue:`28663`) +- :meth:`Styler.background_gradient` now accepts ``vmin`` and ``vmax`` arguments (:issue:`12145`) Build Changes ^^^^^^^^^^^^^ @@ -386,6 +387,7 @@ I/O - Bug in :meth:`DataFrame.read_excel` with ``engine='ods'`` when ``sheet_name`` argument references a non-existent sheet (:issue:`27676`) - Bug in :meth:`pandas.io.formats.style.Styler` formatting for floating values not displaying decimals correctly (:issue:`13257`) - Bug in :meth:`DataFrame.to_html` when using ``formatters=`` and ``max_cols`` together. (:issue:`25955`) +- Bug in :meth:`Styler.background_gradient` not able to work with dtype ``Int64`` (:issue:`28869`) Plotting ^^^^^^^^ diff --git a/pandas/io/formats/style.py b/pandas/io/formats/style.py index 545d6a674411a..9865087a26ae3 100644 --- a/pandas/io/formats/style.py +++ b/pandas/io/formats/style.py @@ -8,6 +8,7 @@ import copy from functools import partial from itertools import product +from typing import Optional from uuid import uuid1 import numpy as np @@ -18,7 +19,6 @@ from pandas.util._decorators import Appender from pandas.core.dtypes.common import is_float, is_string_like -from pandas.core.dtypes.generic import ABCSeries import pandas as pd from pandas.api.types import is_dict_like, is_list_like @@ -963,6 +963,8 @@ def background_gradient( axis=0, subset=None, text_color_threshold=0.408, + vmin: Optional[float] = None, + vmax: Optional[float] = None, ): """ Color the background in a gradient style. @@ -991,6 +993,18 @@ def background_gradient( .. versionadded:: 0.24.0 + vmin : float, optional + Minimum data value that corresponds to colormap minimum value. + When None (default): the minimum value of the data will be used. + + .. versionadded:: 1.0.0 + + vmax : float, optional + Maximum data value that corresponds to colormap maximum value. + When None (default): the maximum value of the data will be used. + + .. versionadded:: 1.0.0 + Returns ------- self : Styler @@ -1017,11 +1031,21 @@ def background_gradient( low=low, high=high, text_color_threshold=text_color_threshold, + vmin=vmin, + vmax=vmax, ) return self @staticmethod - def _background_gradient(s, cmap="PuBu", low=0, high=0, text_color_threshold=0.408): + def _background_gradient( + s, + cmap="PuBu", + low=0, + high=0, + text_color_threshold=0.408, + vmin: Optional[float] = None, + vmax: Optional[float] = None, + ): """ Color background in a range according to the data. """ @@ -1033,14 +1057,14 @@ def _background_gradient(s, cmap="PuBu", low=0, high=0, text_color_threshold=0.4 raise ValueError(msg) with _mpl(Styler.background_gradient) as (plt, colors): - smin = s.values.min() - smax = s.values.max() + smin = np.nanmin(s.to_numpy()) if vmin is None else vmin + smax = np.nanmax(s.to_numpy()) if vmax is None else vmax rng = smax - smin # extend lower / upper bounds, compresses color range norm = colors.Normalize(smin - (rng * low), smax + (rng * high)) # matplotlib colors.Normalize modifies inplace? # https://github.com/matplotlib/matplotlib/issues/5427 - rgbas = plt.cm.get_cmap(cmap)(norm(s.values)) + rgbas = plt.cm.get_cmap(cmap)(norm(s.to_numpy(dtype=float))) def relative_luminance(rgba): """ @@ -1111,12 +1135,8 @@ def _bar(s, align, colors, width=100, vmin=None, vmax=None): Draw bar chart in dataframe cells. """ # Get input value range. - smin = s.min() if vmin is None else vmin - if isinstance(smin, ABCSeries): - smin = smin.min() - smax = s.max() if vmax is None else vmax - if isinstance(smax, ABCSeries): - smax = smax.max() + smin = np.nanmin(s.to_numpy()) if vmin is None else vmin + smax = np.nanmax(s.to_numpy()) if vmax is None else vmax if align == "mid": smin = min(0, smin) smax = max(0, smax) @@ -1125,7 +1145,7 @@ def _bar(s, align, colors, width=100, vmin=None, vmax=None): smax = max(abs(smin), abs(smax)) smin = -smax # Transform to percent-range of linear-gradient - normed = width * (s.values - smin) / (smax - smin + 1e-12) + normed = width * (s.to_numpy(dtype=float) - smin) / (smax - smin + 1e-12) zero = -width * smin / (smax - smin + 1e-12) def css_bar(start, end, color): @@ -1304,17 +1324,15 @@ def _highlight_extrema(data, color="yellow", max_=True): Highlight the min or max in a Series or DataFrame. """ attr = "background-color: {0}".format(color) + + if max_: + extrema = data == np.nanmax(data.to_numpy()) + else: + extrema = data == np.nanmin(data.to_numpy()) + if data.ndim == 1: # Series from .apply - if max_: - extrema = data == data.max() - else: - extrema = data == data.min() return [attr if v else "" for v in extrema] else: # DataFrame from .tee - if max_: - extrema = data == data.max().max() - else: - extrema = data == data.min().min() return pd.DataFrame( np.where(extrema, attr, ""), index=data.index, columns=data.columns ) diff --git a/pandas/tests/io/formats/test_style.py b/pandas/tests/io/formats/test_style.py index 0f1402d7da389..fa725ccae66f9 100644 --- a/pandas/tests/io/formats/test_style.py +++ b/pandas/tests/io/formats/test_style.py @@ -1648,6 +1648,23 @@ def test_background_gradient_axis(self): assert result[(1, 0)] == mid assert result[(1, 1)] == high + def test_background_gradient_vmin_vmax(self): + # GH 12145 + df = pd.DataFrame(range(5)) + ctx = df.style.background_gradient(vmin=1, vmax=3)._compute().ctx + assert ctx[(0, 0)] == ctx[(1, 0)] + assert ctx[(4, 0)] == ctx[(3, 0)] + + def test_background_gradient_int64(self): + # GH 28869 + df1 = pd.Series(range(3)).to_frame() + df2 = pd.Series(range(3), dtype="Int64").to_frame() + ctx1 = df1.style.background_gradient()._compute().ctx + ctx2 = df2.style.background_gradient()._compute().ctx + assert ctx2[(0, 0)] == ctx1[(0, 0)] + assert ctx2[(1, 0)] == ctx1[(1, 0)] + assert ctx2[(2, 0)] == ctx1[(2, 0)] + def test_block_names(): # catch accidental removal of a block