Skip to content

ENH: add set_td_classes method for CSS class addition to data cells #36159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 13, 2020
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Other enhancements
- :meth:`DataFrame.applymap` now supports ``na_action`` (:issue:`23803`)
- :class:`Index` with object dtype supports division and multiplication (:issue:`34160`)
- :meth:`DataFrame.explode` and :meth:`Series.explode` now support exploding of sets (:issue:`35614`)
-
- `Styler` now allows direct CSS class name addition to individual data cells (:issue:`36159`)

.. _whatsnew_120.api_breaking.python:

Expand Down
69 changes: 68 additions & 1 deletion pandas/io/formats/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def __init__(
self.cell_ids = cell_ids
self.na_rep = na_rep

self.cell_context: Dict[str, Any] = {}

# display_funcs maps (row, col) -> formatting function

def default_display_func(x):
Expand Down Expand Up @@ -262,7 +264,7 @@ def format_attr(pair):
idx_lengths = _get_level_lengths(self.index)
col_lengths = _get_level_lengths(self.columns, hidden_columns)

cell_context = dict()
cell_context = self.cell_context

n_rlvls = self.data.index.nlevels
n_clvls = self.data.columns.nlevels
Expand Down Expand Up @@ -499,6 +501,70 @@ def format(self, formatter, subset=None, na_rep: Optional[str] = None) -> "Style
self._display_funcs[(i, j)] = formatter
return self

def set_td_classes(self, classes: DataFrame) -> "Styler":
"""
Add string based CSS class names to data cells that will appear within the
`Styler` HTML result. These classes are added within specified `<td>` elements.

Parameters
----------
classes : DataFrame
DataFrame containing strings that will be translated to CSS classes,
mapped by identical column and index values that must exist on the
underlying `Styler` data. None, NaN values, and empty strings will
be ignored and not affect the rendered HTML.

Returns
-------
self : Styler

Examples
--------
>>> df = pd.DataFrame(data=[[1, 2, 3], [4, 5, 6]], columns=["A", "B", "C"])
>>> classes = pd.DataFrame([
... ["min-val red", "", "blue"],
... ["red", None, "blue max-val"]
... ], index=df.index, columns=df.columns)
>>> df.style.set_td_classes(classes)

Using `MultiIndex` columns and a `classes` `DataFrame` as a subset of the
underlying,

>>> df = pd.DataFrame([[1,2],[3,4]], index=["a", "b"],
... columns=[["level0", "level0"], ["level1a", "level1b"]])
>>> classes = pd.DataFrame(["min-val"], index=["a"],
... columns=[["level0"],["level1a"]])
>>> df.style.set_td_classes(classes)

Form of the output with new additional css classes,

>>> df = pd.DataFrame([[1]])
>>> css = pd.DataFrame(["other-class"])
>>> s = Styler(df, uuid="_", cell_ids=False).set_td_classes(css)
>>> s.hide_index().render()
'<style type="text/css" ></style>'
'<table id="T__" >'
' <thead>'
' <tr><th class="col_heading level0 col0" >0</th></tr>'
' </thead>'
' <tbody>'
' <tr><td class="data row0 col0 other-class" >1</td></tr>'
' </tbody>'
'</table>'

"""
classes = classes.reindex_like(self.data)

mask = (classes.isna()) | (classes.eq(""))
self.cell_context["data"] = {
r: {c: [str(classes.iloc[r, c])]}
for r, rn in enumerate(classes.index)
for c, cn in enumerate(classes.columns)
if not mask.iloc[r, c]
}

return self

def render(self, **kwargs) -> str:
"""
Render the built up styles to HTML.
Expand Down Expand Up @@ -609,6 +675,7 @@ def clear(self) -> None:
Returns None.
"""
self.ctx.clear()
self.cell_context = {}
self._todo = []

def _compute(self):
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/io/formats/test_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -1691,6 +1691,27 @@ def test_no_cell_ids(self):
s = styler.render() # render twice to ensure ctx is not updated
assert s.find('<td class="data row0 col0" >') != -1

@pytest.mark.parametrize(
"classes",
[
DataFrame(
data=[["", "test-class"], [np.nan, None]],
columns=["A", "B"],
index=["a", "b"],
),
DataFrame(data=[["test-class"]], columns=["B"], index=["a"]),
DataFrame(data=[["test-class", "unused"]], columns=["B", "C"], index=["a"]),
],
)
def test_set_data_classes(self, classes):
# GH 36159
df = DataFrame(data=[[0, 1], [2, 3]], columns=["A", "B"], index=["a", "b"])
s = Styler(df, uuid="_", cell_ids=False).set_td_classes(classes).render()
assert '<td class="data row0 col0" >0</td>' in s
assert '<td class="data row0 col1 test-class" >1</td>' in s
assert '<td class="data row1 col0" >2</td>' in s
assert '<td class="data row1 col1" >3</td>' in s

def test_colspan_w3(self):
# GH 36223
df = pd.DataFrame(data=[[1, 2]], columns=[["l0", "l0"], ["l1a", "l1b"]])
Expand Down