Skip to content

Commit 897c720

Browse files
attack68feefladder
authored andcommitted
BUG: Styler.apply consistently manages Series return objects aligning labels. (pandas-dev#42014)
1 parent bb57265 commit 897c720

File tree

3 files changed

+93
-32
lines changed

3 files changed

+93
-32
lines changed

doc/source/whatsnew/v1.4.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,8 @@ Styler
365365
- Minor bug in :class:`.Styler` where the ``uuid`` at initialization maintained a floating underscore (:issue:`43037`)
366366
- Bug in :meth:`.Styler.to_html` where the ``Styler`` object was updated if the ``to_html`` method was called with some args (:issue:`43034`)
367367
- Bug in :meth:`.Styler.copy` where ``uuid`` was not previously copied (:issue:`40675`)
368+
- Bug in :meth:`Styler.apply` where functions which returned Series objects were not correctly handled in terms of aligning their index labels (:issue:`13657`, :issue:`42014`)
369+
-
368370

369371
Other
370372
^^^^^

pandas/io/formats/style.py

+39-19
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,7 @@ def _update_ctx(self, attrs: DataFrame) -> None:
10241024

10251025
for cn in attrs.columns:
10261026
for rn, c in attrs[[cn]].itertuples():
1027-
if not c:
1027+
if not c or pd.isna(c):
10281028
continue
10291029
css_list = maybe_convert_css_to_tuples(c)
10301030
i, j = self.index.get_loc(rn), self.columns.get_loc(cn)
@@ -1148,9 +1148,10 @@ def _apply(
11481148
subset = slice(None) if subset is None else subset
11491149
subset = non_reducing_slice(subset)
11501150
data = self.data.loc[subset]
1151-
if axis is not None:
1152-
result = data.apply(func, axis=axis, result_type="expand", **kwargs)
1153-
result.columns = data.columns
1151+
if axis in [0, "index"]:
1152+
result = data.apply(func, axis=0, **kwargs)
1153+
elif axis in [1, "columns"]:
1154+
result = data.T.apply(func, axis=0, **kwargs).T # see GH 42005
11541155
else:
11551156
result = func(data, **kwargs)
11561157
if not isinstance(result, DataFrame):
@@ -1166,19 +1167,28 @@ def _apply(
11661167
f"Expected shape: {data.shape}"
11671168
)
11681169
result = DataFrame(result, index=data.index, columns=data.columns)
1169-
elif not (
1170-
result.index.equals(data.index) and result.columns.equals(data.columns)
1171-
):
1172-
raise ValueError(
1173-
f"Result of {repr(func)} must have identical "
1174-
f"index and columns as the input"
1175-
)
11761170

1177-
if result.shape != data.shape:
1171+
if isinstance(result, Series):
11781172
raise ValueError(
1179-
f"Function {repr(func)} returned the wrong shape.\n"
1180-
f"Result has shape: {result.shape}\n"
1181-
f"Expected shape: {data.shape}"
1173+
f"Function {repr(func)} resulted in the apply method collapsing to a "
1174+
f"Series.\nUsually, this is the result of the function returning a "
1175+
f"single value, instead of list-like."
1176+
)
1177+
msg = (
1178+
f"Function {repr(func)} created invalid {{0}} labels.\nUsually, this is "
1179+
f"the result of the function returning a "
1180+
f"{'Series' if axis is not None else 'DataFrame'} which contains invalid "
1181+
f"labels, or returning an incorrectly shaped, list-like object which "
1182+
f"cannot be mapped to labels, possibly due to applying the function along "
1183+
f"the wrong axis.\n"
1184+
f"Result {{0}} has shape: {{1}}\n"
1185+
f"Expected {{0}} shape: {{2}}"
1186+
)
1187+
if not all(result.index.isin(data.index)):
1188+
raise ValueError(msg.format("index", result.index.shape, data.index.shape))
1189+
if not all(result.columns.isin(data.columns)):
1190+
raise ValueError(
1191+
msg.format("columns", result.columns.shape, data.columns.shape)
11821192
)
11831193
self._update_ctx(result)
11841194
return self
@@ -1198,14 +1208,17 @@ def apply(
11981208
Parameters
11991209
----------
12001210
func : function
1201-
``func`` should take a Series if ``axis`` in [0,1] and return an object
1202-
of same length, also with identical index if the object is a Series.
1211+
``func`` should take a Series if ``axis`` in [0,1] and return a list-like
1212+
object of same length, or a Series, not necessarily of same length, with
1213+
valid index labels considering ``subset``.
12031214
``func`` should take a DataFrame if ``axis`` is ``None`` and return either
1204-
an ndarray with the same shape or a DataFrame with identical columns and
1205-
index.
1215+
an ndarray with the same shape or a DataFrame, not necessarily of the same
1216+
shape, with valid index and columns labels considering ``subset``.
12061217
12071218
.. versionchanged:: 1.3.0
12081219
1220+
.. versionchanged:: 1.4.0
1221+
12091222
axis : {0 or 'index', 1 or 'columns', None}, default 0
12101223
Apply to each column (``axis=0`` or ``'index'``), to each row
12111224
(``axis=1`` or ``'columns'``), or to the entire DataFrame at once
@@ -1260,6 +1273,13 @@ def apply(
12601273
>>> df.style.apply(highlight_max, color='red', subset=(slice(0,5,2), "A"))
12611274
... # doctest: +SKIP
12621275
1276+
Using a function which returns a Series / DataFrame of unequal length but
1277+
containing valid index labels
1278+
1279+
>>> df = pd.DataFrame([[1, 2], [3, 4], [4, 6]], index=["A1", "A2", "Total"])
1280+
>>> total_style = pd.Series("font-weight: bold;", index=["Total"])
1281+
>>> df.style.apply(lambda s: total_style) # doctest: +SKIP
1282+
12631283
See `Table Visualization <../../user_guide/style.ipynb>`_ user guide for
12641284
more details.
12651285
"""

pandas/tests/io/formats/style/test_style.py

+52-13
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,40 @@ def test_apply_axis(self):
550550
result._compute()
551551
assert result.ctx == expected
552552

553+
@pytest.mark.parametrize("axis", [0, 1])
554+
def test_apply_series_return(self, axis):
555+
# GH 42014
556+
df = DataFrame([[1, 2], [3, 4]], index=["X", "Y"], columns=["X", "Y"])
557+
558+
# test Series return where len(Series) < df.index or df.columns but labels OK
559+
func = lambda s: pd.Series(["color: red;"], index=["Y"])
560+
result = df.style.apply(func, axis=axis)._compute().ctx
561+
assert result[(1, 1)] == [("color", "red")]
562+
assert result[(1 - axis, axis)] == [("color", "red")]
563+
564+
# test Series return where labels align but different order
565+
func = lambda s: pd.Series(["color: red;", "color: blue;"], index=["Y", "X"])
566+
result = df.style.apply(func, axis=axis)._compute().ctx
567+
assert result[(0, 0)] == [("color", "blue")]
568+
assert result[(1, 1)] == [("color", "red")]
569+
assert result[(1 - axis, axis)] == [("color", "red")]
570+
assert result[(axis, 1 - axis)] == [("color", "blue")]
571+
572+
@pytest.mark.parametrize("index", [False, True])
573+
@pytest.mark.parametrize("columns", [False, True])
574+
def test_apply_dataframe_return(self, index, columns):
575+
# GH 42014
576+
df = DataFrame([[1, 2], [3, 4]], index=["X", "Y"], columns=["X", "Y"])
577+
idxs = ["X", "Y"] if index else ["Y"]
578+
cols = ["X", "Y"] if columns else ["Y"]
579+
df_styles = DataFrame("color: red;", index=idxs, columns=cols)
580+
result = df.style.apply(lambda x: df_styles, axis=None)._compute().ctx
581+
582+
assert result[(1, 1)] == [("color", "red")] # (Y,Y) styles always present
583+
assert (result[(0, 1)] == [("color", "red")]) is index # (X,Y) only if index
584+
assert (result[(1, 0)] == [("color", "red")]) is columns # (Y,X) only if cols
585+
assert (result[(0, 0)] == [("color", "red")]) is (index and columns) # (X,X)
586+
553587
@pytest.mark.parametrize(
554588
"slice_",
555589
[
@@ -794,24 +828,28 @@ def test_export(self):
794828
style2.to_html()
795829

796830
def test_bad_apply_shape(self):
797-
df = DataFrame([[1, 2], [3, 4]])
798-
msg = "returned the wrong shape"
799-
with pytest.raises(ValueError, match=msg):
800-
df.style._apply(lambda x: "x", subset=pd.IndexSlice[[0, 1], :])
831+
df = DataFrame([[1, 2], [3, 4]], index=["A", "B"], columns=["X", "Y"])
801832

833+
msg = "resulted in the apply method collapsing to a Series."
802834
with pytest.raises(ValueError, match=msg):
803-
df.style._apply(lambda x: [""], subset=pd.IndexSlice[[0, 1], :])
835+
df.style._apply(lambda x: "x")
804836

805-
with pytest.raises(ValueError, match=msg):
837+
msg = "created invalid {} labels"
838+
with pytest.raises(ValueError, match=msg.format("index")):
839+
df.style._apply(lambda x: [""])
840+
841+
with pytest.raises(ValueError, match=msg.format("index")):
806842
df.style._apply(lambda x: ["", "", "", ""])
807843

808-
with pytest.raises(ValueError, match=msg):
809-
df.style._apply(lambda x: ["", "", ""], subset=1)
844+
with pytest.raises(ValueError, match=msg.format("index")):
845+
df.style._apply(lambda x: pd.Series(["a:v;", ""], index=["A", "C"]), axis=0)
810846

811-
msg = "Length mismatch: Expected axis has 3 elements"
812-
with pytest.raises(ValueError, match=msg):
847+
with pytest.raises(ValueError, match=msg.format("columns")):
813848
df.style._apply(lambda x: ["", "", ""], axis=1)
814849

850+
with pytest.raises(ValueError, match=msg.format("columns")):
851+
df.style._apply(lambda x: pd.Series(["a:v;", ""], index=["X", "Z"]), axis=1)
852+
815853
msg = "returned ndarray with wrong shape"
816854
with pytest.raises(ValueError, match=msg):
817855
df.style._apply(lambda x: np.array([[""], [""]]), axis=None)
@@ -828,12 +866,13 @@ def f(x):
828866
with pytest.raises(TypeError, match=msg):
829867
df.style._apply(f, axis=None)
830868

831-
def test_apply_bad_labels(self):
869+
@pytest.mark.parametrize("axis", ["index", "columns"])
870+
def test_apply_bad_labels(self, axis):
832871
def f(x):
833-
return DataFrame(index=[1, 2], columns=["a", "b"])
872+
return DataFrame(**{axis: ["bad", "labels"]})
834873

835874
df = DataFrame([[1, 2], [3, 4]])
836-
msg = "must have identical index and columns as the input"
875+
msg = f"created invalid {axis} labels."
837876
with pytest.raises(ValueError, match=msg):
838877
df.style._apply(f, axis=None)
839878

0 commit comments

Comments
 (0)