Skip to content

Commit 9a9fcf6

Browse files
authored
BUG: fix Series.apply(..., by_row), v2. (#53601)
1 parent 6d50df1 commit 9a9fcf6

File tree

6 files changed

+199
-31
lines changed

6 files changed

+199
-31
lines changed

doc/source/whatsnew/v2.1.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ Other enhancements
113113
- :meth:`SeriesGroupby.transform` and :meth:`DataFrameGroupby.transform` now support passing in a string as the function for ``engine="numba"`` (:issue:`53579`)
114114
- Added :meth:`ExtensionArray.interpolate` used by :meth:`Series.interpolate` and :meth:`DataFrame.interpolate` (:issue:`53659`)
115115
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
116-
- Added a new parameter ``by_row`` to :meth:`Series.apply`. When set to ``False`` the supplied callables will always operate on the whole Series (:issue:`53400`).
116+
- Added a new parameter ``by_row`` to :meth:`Series.apply` and :meth:`DataFrame.apply`. When set to ``False`` the supplied callables will always operate on the whole Series or DataFrame (:issue:`53400`, :issue:`53601`).
117117
- Groupby aggregations (such as :meth:`DataFrameGroupby.sum`) now can preserve the dtype of the input instead of casting to ``float64`` (:issue:`44952`)
118118
- Improved error message when :meth:`DataFrameGroupBy.agg` failed (:issue:`52930`)
119119
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`)

pandas/core/apply.py

+64-5
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def frame_apply(
8181
axis: Axis = 0,
8282
raw: bool = False,
8383
result_type: str | None = None,
84+
by_row: Literal[False, "compat"] = "compat",
8485
args=None,
8586
kwargs=None,
8687
) -> FrameApply:
@@ -100,6 +101,7 @@ def frame_apply(
100101
func,
101102
raw=raw,
102103
result_type=result_type,
104+
by_row=by_row,
103105
args=args,
104106
kwargs=kwargs,
105107
)
@@ -115,11 +117,16 @@ def __init__(
115117
raw: bool,
116118
result_type: str | None,
117119
*,
120+
by_row: Literal[False, "compat", "_compat"] = "compat",
118121
args,
119122
kwargs,
120123
) -> None:
121124
self.obj = obj
122125
self.raw = raw
126+
127+
assert by_row is False or by_row in ["compat", "_compat"]
128+
self.by_row = by_row
129+
123130
self.args = args or ()
124131
self.kwargs = kwargs or {}
125132

@@ -304,7 +311,14 @@ def agg_or_apply_list_like(
304311
func = cast(List[AggFuncTypeBase], self.func)
305312
kwargs = self.kwargs
306313
if op_name == "apply":
307-
kwargs = {**kwargs, "by_row": False}
314+
if isinstance(self, FrameApply):
315+
by_row = self.by_row
316+
317+
elif isinstance(self, SeriesApply):
318+
by_row = "_compat" if self.by_row else False
319+
else:
320+
by_row = False
321+
kwargs = {**kwargs, "by_row": by_row}
308322

309323
if getattr(obj, "axis", 0) == 1:
310324
raise NotImplementedError("axis other than 0 is not supported")
@@ -397,7 +411,10 @@ def agg_or_apply_dict_like(
397411

398412
obj = self.obj
399413
func = cast(AggFuncTypeDict, self.func)
400-
kwargs = {"by_row": False} if op_name == "apply" else {}
414+
kwargs = {}
415+
if op_name == "apply":
416+
by_row = "_compat" if self.by_row else False
417+
kwargs.update({"by_row": by_row})
401418

402419
if getattr(obj, "axis", 0) == 1:
403420
raise NotImplementedError("axis other than 0 is not supported")
@@ -678,6 +695,23 @@ def agg_axis(self) -> Index:
678695
class FrameApply(NDFrameApply):
679696
obj: DataFrame
680697

698+
def __init__(
699+
self,
700+
obj: AggObjType,
701+
func: AggFuncType,
702+
raw: bool,
703+
result_type: str | None,
704+
*,
705+
by_row: Literal[False, "compat"] = False,
706+
args,
707+
kwargs,
708+
) -> None:
709+
if by_row is not False and by_row != "compat":
710+
raise ValueError(f"by_row={by_row} not allowed")
711+
super().__init__(
712+
obj, func, raw, result_type, by_row=by_row, args=args, kwargs=kwargs
713+
)
714+
681715
# ---------------------------------------------------------------
682716
# Abstract Methods
683717

@@ -1067,15 +1101,15 @@ def infer_to_same_shape(self, results: ResType, res_index: Index) -> DataFrame:
10671101
class SeriesApply(NDFrameApply):
10681102
obj: Series
10691103
axis: AxisInt = 0
1070-
by_row: bool # only relevant for apply()
1104+
by_row: Literal[False, "compat", "_compat"] # only relevant for apply()
10711105

10721106
def __init__(
10731107
self,
10741108
obj: Series,
10751109
func: AggFuncType,
10761110
*,
10771111
convert_dtype: bool | lib.NoDefault = lib.no_default,
1078-
by_row: bool = True,
1112+
by_row: Literal[False, "compat", "_compat"] = "compat",
10791113
args,
10801114
kwargs,
10811115
) -> None:
@@ -1090,13 +1124,13 @@ def __init__(
10901124
stacklevel=find_stack_level(),
10911125
)
10921126
self.convert_dtype = convert_dtype
1093-
self.by_row = by_row
10941127

10951128
super().__init__(
10961129
obj,
10971130
func,
10981131
raw=False,
10991132
result_type=None,
1133+
by_row=by_row,
11001134
args=args,
11011135
kwargs=kwargs,
11021136
)
@@ -1115,6 +1149,9 @@ def apply(self) -> DataFrame | Series:
11151149
# if we are a string, try to dispatch
11161150
return self.apply_str()
11171151

1152+
if self.by_row == "_compat":
1153+
return self.apply_compat()
1154+
11181155
# self.func is Callable
11191156
return self.apply_standard()
11201157

@@ -1149,6 +1186,28 @@ def apply_empty_result(self) -> Series:
11491186
obj, method="apply"
11501187
)
11511188

1189+
def apply_compat(self):
1190+
"""compat apply method for funcs in listlikes and dictlikes.
1191+
1192+
Used for each callable when giving listlikes and dictlikes of callables to
1193+
apply. Needed for compatibility with Pandas < v2.1.
1194+
1195+
.. versionadded:: 2.1.0
1196+
"""
1197+
obj = self.obj
1198+
func = self.func
1199+
1200+
if callable(func):
1201+
f = com.get_cython_func(func)
1202+
if f and not self.args and not self.kwargs:
1203+
return obj.apply(func, by_row=False)
1204+
1205+
try:
1206+
result = obj.apply(func, by_row="compat")
1207+
except (ValueError, AttributeError, TypeError):
1208+
result = obj.apply(func, by_row=False)
1209+
return result
1210+
11521211
def apply_standard(self) -> DataFrame | Series:
11531212
# caller is responsible for ensuring that f is Callable
11541213
func = cast(Callable, self.func)

pandas/core/frame.py

+13
Original file line numberDiff line numberDiff line change
@@ -9634,6 +9634,7 @@ def apply(
96349634
raw: bool = False,
96359635
result_type: Literal["expand", "reduce", "broadcast"] | None = None,
96369636
args=(),
9637+
by_row: Literal[False, "compat"] = "compat",
96379638
**kwargs,
96389639
):
96399640
"""
@@ -9682,6 +9683,17 @@ def apply(
96829683
args : tuple
96839684
Positional arguments to pass to `func` in addition to the
96849685
array/series.
9686+
by_row : False or "compat", default "compat"
9687+
Only has an effect when ``func`` is a listlike or dictlike of funcs
9688+
and the func isn't a string.
9689+
If "compat", will if possible first translate the func into pandas
9690+
methods (e.g. ``Series().apply(np.sum)`` will be translated to
9691+
``Series().sum()``). If that doesn't work, will try to call apply again with
9692+
``by_row="compat"`` and if that fails, will call apply again with
9693+
``by_row=False`` (backward compatible).
9694+
If False, the funcs will be passed the whole Series at once.
9695+
9696+
.. versionadded:: 2.1.0
96859697
**kwargs
96869698
Additional keyword arguments to pass as keywords arguments to
96879699
`func`.
@@ -9781,6 +9793,7 @@ def apply(
97819793
axis=axis,
97829794
raw=raw,
97839795
result_type=result_type,
9796+
by_row=by_row,
97849797
args=args,
97859798
kwargs=kwargs,
97869799
)

pandas/core/series.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -4538,7 +4538,7 @@ def apply(
45384538
convert_dtype: bool | lib.NoDefault = lib.no_default,
45394539
args: tuple[Any, ...] = (),
45404540
*,
4541-
by_row: bool = True,
4541+
by_row: Literal[False, "compat"] = "compat",
45424542
**kwargs,
45434543
) -> DataFrame | Series:
45444544
"""
@@ -4562,14 +4562,20 @@ def apply(
45624562
preserved for some extension array dtypes, such as Categorical.
45634563
45644564
.. deprecated:: 2.1.0
4565-
The convert_dtype has been deprecated. Do ``ser.astype(object).apply()``
4565+
``convert_dtype`` has been deprecated. Do ``ser.astype(object).apply()``
45664566
instead if you want ``convert_dtype=False``.
45674567
args : tuple
45684568
Positional arguments passed to func after the series value.
4569-
by_row : bool, default True
4569+
by_row : False or "compat", default "compat"
4570+
If ``"compat"`` and func is a callable, func will be passed each element of
4571+
the Series, like ``Series.map``. If func is a list or dict of
4572+
callables, will first try to translate each func into pandas methods. If
4573+
that doesn't work, will try to call apply again with ``by_row="compat"``
4574+
and if that fails, will call apply again with ``by_row=False``
4575+
(backward compatible).
45704576
If False, the func will be passed the whole Series at once.
4571-
If True, will func will be passed each element of the Series, like
4572-
Series.map (backward compatible).
4577+
4578+
``by_row`` has no effect when ``func`` is a string.
45734579
45744580
.. versionadded:: 2.1.0
45754581
**kwargs

pandas/tests/apply/test_frame_apply.py

+96
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,50 @@ def test_infer_row_shape():
667667
assert result == (6, 2)
668668

669669

670+
@pytest.mark.parametrize(
671+
"ops, by_row, expected",
672+
[
673+
({"a": lambda x: x + 1}, "compat", DataFrame({"a": [2, 3]})),
674+
({"a": lambda x: x + 1}, False, DataFrame({"a": [2, 3]})),
675+
({"a": lambda x: x.sum()}, "compat", Series({"a": 3})),
676+
({"a": lambda x: x.sum()}, False, Series({"a": 3})),
677+
(
678+
{"a": ["sum", np.sum, lambda x: x.sum()]},
679+
"compat",
680+
DataFrame({"a": [3, 3, 3]}, index=["sum", "sum", "<lambda>"]),
681+
),
682+
(
683+
{"a": ["sum", np.sum, lambda x: x.sum()]},
684+
False,
685+
DataFrame({"a": [3, 3, 3]}, index=["sum", "sum", "<lambda>"]),
686+
),
687+
({"a": lambda x: 1}, "compat", DataFrame({"a": [1, 1]})),
688+
({"a": lambda x: 1}, False, Series({"a": 1})),
689+
],
690+
)
691+
def test_dictlike_lambda(ops, by_row, expected):
692+
# GH53601
693+
df = DataFrame({"a": [1, 2]})
694+
result = df.apply(ops, by_row=by_row)
695+
tm.assert_equal(result, expected)
696+
697+
698+
@pytest.mark.parametrize(
699+
"ops",
700+
[
701+
{"a": lambda x: x + 1},
702+
{"a": lambda x: x.sum()},
703+
{"a": ["sum", np.sum, lambda x: x.sum()]},
704+
{"a": lambda x: 1},
705+
],
706+
)
707+
def test_dictlike_lambda_raises(ops):
708+
# GH53601
709+
df = DataFrame({"a": [1, 2]})
710+
with pytest.raises(ValueError, match="by_row=True not allowed"):
711+
df.apply(ops, by_row=True)
712+
713+
670714
def test_with_dictlike_columns():
671715
# GH 17602
672716
df = DataFrame([[1, 2], [1, 2]], columns=["a", "b"])
@@ -716,6 +760,58 @@ def test_with_dictlike_columns_with_infer():
716760
tm.assert_frame_equal(result, expected)
717761

718762

763+
@pytest.mark.parametrize(
764+
"ops, by_row, expected",
765+
[
766+
([lambda x: x + 1], "compat", DataFrame({("a", "<lambda>"): [2, 3]})),
767+
([lambda x: x + 1], False, DataFrame({("a", "<lambda>"): [2, 3]})),
768+
([lambda x: x.sum()], "compat", DataFrame({"a": [3]}, index=["<lambda>"])),
769+
([lambda x: x.sum()], False, DataFrame({"a": [3]}, index=["<lambda>"])),
770+
(
771+
["sum", np.sum, lambda x: x.sum()],
772+
"compat",
773+
DataFrame({"a": [3, 3, 3]}, index=["sum", "sum", "<lambda>"]),
774+
),
775+
(
776+
["sum", np.sum, lambda x: x.sum()],
777+
False,
778+
DataFrame({"a": [3, 3, 3]}, index=["sum", "sum", "<lambda>"]),
779+
),
780+
(
781+
[lambda x: x + 1, lambda x: 3],
782+
"compat",
783+
DataFrame([[2, 3], [3, 3]], columns=[["a", "a"], ["<lambda>", "<lambda>"]]),
784+
),
785+
(
786+
[lambda x: 2, lambda x: 3],
787+
False,
788+
DataFrame({"a": [2, 3]}, ["<lambda>", "<lambda>"]),
789+
),
790+
],
791+
)
792+
def test_listlike_lambda(ops, by_row, expected):
793+
# GH53601
794+
df = DataFrame({"a": [1, 2]})
795+
result = df.apply(ops, by_row=by_row)
796+
tm.assert_equal(result, expected)
797+
798+
799+
@pytest.mark.parametrize(
800+
"ops",
801+
[
802+
[lambda x: x + 1],
803+
[lambda x: x.sum()],
804+
["sum", np.sum, lambda x: x.sum()],
805+
[lambda x: x + 1, lambda x: 3],
806+
],
807+
)
808+
def test_listlike_lambda_raises(ops):
809+
# GH53601
810+
df = DataFrame({"a": [1, 2]})
811+
with pytest.raises(ValueError, match="by_row=True not allowed"):
812+
df.apply(ops, by_row=True)
813+
814+
719815
def test_with_listlike_columns():
720816
# GH 17348
721817
df = DataFrame(

0 commit comments

Comments
 (0)