Skip to content

Commit d9c3777

Browse files
authored
REF: Decouple Series.apply from Series.agg (#53400)
1 parent 7c6b54f commit d9c3777

File tree

4 files changed

+227
-94
lines changed

4 files changed

+227
-94
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ Other enhancements
101101
- :meth:`DataFrame.unstack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
102102
- :meth:`SeriesGroupby.agg` and :meth:`DataFrameGroupby.agg` now support passing in multiple functions for ``engine="numba"`` (:issue:`53486`)
103103
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
104+
- Added a new parameter ``by_row`` to :meth:`Series.apply`. When set to ``False`` the supplied callables will always operate on the whole Series (:issue:`53400`).
104105
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`)
105106
- Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
106107
- Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)

pandas/core/apply.py

+48-17
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Iterable,
1717
Iterator,
1818
List,
19+
Literal,
1920
Sequence,
2021
cast,
2122
)
@@ -288,6 +289,11 @@ def agg_list_like(self) -> DataFrame | Series:
288289
-------
289290
Result of aggregation.
290291
"""
292+
return self.agg_or_apply_list_like(op_name="agg")
293+
294+
def agg_or_apply_list_like(
295+
self, op_name: Literal["agg", "apply"]
296+
) -> DataFrame | Series:
291297
from pandas.core.groupby.generic import (
292298
DataFrameGroupBy,
293299
SeriesGroupBy,
@@ -296,6 +302,9 @@ def agg_list_like(self) -> DataFrame | Series:
296302

297303
obj = self.obj
298304
func = cast(List[AggFuncTypeBase], self.func)
305+
kwargs = self.kwargs
306+
if op_name == "apply":
307+
kwargs = {**kwargs, "by_row": False}
299308

300309
if getattr(obj, "axis", 0) == 1:
301310
raise NotImplementedError("axis other than 0 is not supported")
@@ -313,8 +322,6 @@ def agg_list_like(self) -> DataFrame | Series:
313322
keys = []
314323

315324
is_groupby = isinstance(obj, (DataFrameGroupBy, SeriesGroupBy))
316-
is_ser_or_df = isinstance(obj, (ABCDataFrame, ABCSeries))
317-
this_args = [self.axis, *self.args] if is_ser_or_df else self.args
318325

319326
context_manager: ContextManager
320327
if is_groupby:
@@ -323,12 +330,19 @@ def agg_list_like(self) -> DataFrame | Series:
323330
context_manager = com.temp_setattr(obj, "as_index", True)
324331
else:
325332
context_manager = nullcontext()
333+
334+
def include_axis(colg) -> bool:
335+
return isinstance(colg, ABCDataFrame) or (
336+
isinstance(colg, ABCSeries) and op_name == "agg"
337+
)
338+
326339
with context_manager:
327340
# degenerate case
328341
if selected_obj.ndim == 1:
329342
for a in func:
330343
colg = obj._gotitem(selected_obj.name, ndim=1, subset=selected_obj)
331-
new_res = colg.aggregate(a, *this_args, **self.kwargs)
344+
args = [self.axis, *self.args] if include_axis(colg) else self.args
345+
new_res = getattr(colg, op_name)(a, *args, **kwargs)
332346
results.append(new_res)
333347

334348
# make sure we find a good name
@@ -339,7 +353,8 @@ def agg_list_like(self) -> DataFrame | Series:
339353
indices = []
340354
for index, col in enumerate(selected_obj):
341355
colg = obj._gotitem(col, ndim=1, subset=selected_obj.iloc[:, index])
342-
new_res = colg.aggregate(func, *this_args, **self.kwargs)
356+
args = [self.axis, *self.args] if include_axis(colg) else self.args
357+
new_res = getattr(colg, op_name)(func, *args, **kwargs)
343358
results.append(new_res)
344359
indices.append(index)
345360
keys = selected_obj.columns.take(indices)
@@ -366,15 +381,23 @@ def agg_dict_like(self) -> DataFrame | Series:
366381
-------
367382
Result of aggregation.
368383
"""
384+
return self.agg_or_apply_dict_like(op_name="agg")
385+
386+
def agg_or_apply_dict_like(
387+
self, op_name: Literal["agg", "apply"]
388+
) -> DataFrame | Series:
369389
from pandas import Index
370390
from pandas.core.groupby.generic import (
371391
DataFrameGroupBy,
372392
SeriesGroupBy,
373393
)
374394
from pandas.core.reshape.concat import concat
375395

396+
assert op_name in ["agg", "apply"]
397+
376398
obj = self.obj
377399
func = cast(AggFuncTypeDict, self.func)
400+
kwargs = {"by_row": False} if op_name == "apply" else {}
378401

379402
if getattr(obj, "axis", 0) == 1:
380403
raise NotImplementedError("axis other than 0 is not supported")
@@ -387,7 +410,7 @@ def agg_dict_like(self) -> DataFrame | Series:
387410
selected_obj = obj._selected_obj
388411
selection = obj._selection
389412

390-
func = self.normalize_dictlike_arg("agg", selected_obj, func)
413+
func = self.normalize_dictlike_arg(op_name, selected_obj, func)
391414

392415
is_groupby = isinstance(obj, (DataFrameGroupBy, SeriesGroupBy))
393416
context_manager: ContextManager
@@ -404,17 +427,18 @@ def agg_dict_like(self) -> DataFrame | Series:
404427
)
405428

406429
# Numba Groupby engine/engine-kwargs passthrough
407-
kwargs = {}
408430
if is_groupby:
409431
engine = self.kwargs.get("engine", None)
410432
engine_kwargs = self.kwargs.get("engine_kwargs", None)
411-
kwargs = {"engine": engine, "engine_kwargs": engine_kwargs}
433+
kwargs.update({"engine": engine, "engine_kwargs": engine_kwargs})
412434

413435
with context_manager:
414436
if selected_obj.ndim == 1:
415437
# key only used for output
416438
colg = obj._gotitem(selection, ndim=1)
417-
result_data = [colg.agg(how, **kwargs) for _, how in func.items()]
439+
result_data = [
440+
getattr(colg, op_name)(how, **kwargs) for _, how in func.items()
441+
]
418442
result_index = list(func.keys())
419443
elif is_non_unique_col:
420444
# key used for column selection and output
@@ -429,7 +453,9 @@ def agg_dict_like(self) -> DataFrame | Series:
429453
label_to_indices[label].append(index)
430454

431455
key_data = [
432-
selected_obj._ixs(indice, axis=1).agg(how, **kwargs)
456+
getattr(selected_obj._ixs(indice, axis=1), op_name)(
457+
how, **kwargs
458+
)
433459
for label, indices in label_to_indices.items()
434460
for indice in indices
435461
]
@@ -439,7 +465,7 @@ def agg_dict_like(self) -> DataFrame | Series:
439465
else:
440466
# key used for column selection and output
441467
result_data = [
442-
obj._gotitem(key, ndim=1).agg(how, **kwargs)
468+
getattr(obj._gotitem(key, ndim=1), op_name)(how, **kwargs)
443469
for key, how in func.items()
444470
]
445471
result_index = list(func.keys())
@@ -535,7 +561,7 @@ def apply_str(self) -> DataFrame | Series:
535561
self.kwargs["axis"] = self.axis
536562
return self._apply_str(obj, func, *self.args, **self.kwargs)
537563

538-
def apply_multiple(self) -> DataFrame | Series:
564+
def apply_list_or_dict_like(self) -> DataFrame | Series:
539565
"""
540566
Compute apply in case of a list-like or dict-like.
541567
@@ -551,9 +577,9 @@ def apply_multiple(self) -> DataFrame | Series:
551577
kwargs = self.kwargs
552578

553579
if is_dict_like(func):
554-
result = self.agg_dict_like()
580+
result = self.agg_or_apply_dict_like(op_name="apply")
555581
else:
556-
result = self.agg_list_like()
582+
result = self.agg_or_apply_list_like(op_name="apply")
557583

558584
result = reconstruct_and_relabel_result(result, func, **kwargs)
559585

@@ -692,9 +718,9 @@ def values(self):
692718

693719
def apply(self) -> DataFrame | Series:
694720
"""compute the results"""
695-
# dispatch to agg
721+
# dispatch to handle list-like or dict-like
696722
if is_list_like(self.func):
697-
return self.apply_multiple()
723+
return self.apply_list_or_dict_like()
698724

699725
# all empty
700726
if len(self.columns) == 0 and len(self.index) == 0:
@@ -1041,13 +1067,15 @@ def infer_to_same_shape(self, results: ResType, res_index: Index) -> DataFrame:
10411067
class SeriesApply(NDFrameApply):
10421068
obj: Series
10431069
axis: AxisInt = 0
1070+
by_row: bool # only relevant for apply()
10441071

10451072
def __init__(
10461073
self,
10471074
obj: Series,
10481075
func: AggFuncType,
10491076
*,
10501077
convert_dtype: bool | lib.NoDefault = lib.no_default,
1078+
by_row: bool = True,
10511079
args,
10521080
kwargs,
10531081
) -> None:
@@ -1062,6 +1090,7 @@ def __init__(
10621090
stacklevel=find_stack_level(),
10631091
)
10641092
self.convert_dtype = convert_dtype
1093+
self.by_row = by_row
10651094

10661095
super().__init__(
10671096
obj,
@@ -1078,9 +1107,9 @@ def apply(self) -> DataFrame | Series:
10781107
if len(obj) == 0:
10791108
return self.apply_empty_result()
10801109

1081-
# dispatch to agg
1110+
# dispatch to handle list-like or dict-like
10821111
if is_list_like(self.func):
1083-
return self.apply_multiple()
1112+
return self.apply_list_or_dict_like()
10841113

10851114
if isinstance(self.func, str):
10861115
# if we are a string, try to dispatch
@@ -1126,6 +1155,8 @@ def apply_standard(self) -> DataFrame | Series:
11261155
if isinstance(func, np.ufunc):
11271156
with np.errstate(all="ignore"):
11281157
return func(obj, *self.args, **self.kwargs)
1158+
elif not self.by_row:
1159+
return func(obj, *self.args, **self.kwargs)
11291160

11301161
if self.args or self.kwargs:
11311162
# _map_values does not support args/kwargs

pandas/core/series.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -4496,6 +4496,8 @@ def apply(
44964496
func: AggFuncType,
44974497
convert_dtype: bool | lib.NoDefault = lib.no_default,
44984498
args: tuple[Any, ...] = (),
4499+
*,
4500+
by_row: bool = True,
44994501
**kwargs,
45004502
) -> DataFrame | Series:
45014503
"""
@@ -4523,6 +4525,12 @@ def apply(
45234525
instead if you want ``convert_dtype=False``.
45244526
args : tuple
45254527
Positional arguments passed to func after the series value.
4528+
by_row : bool, default True
4529+
If False, the func will be passed the whole Series at once.
4530+
If True, will func will be passed each element of the Series, like
4531+
Series.map (backward compatible).
4532+
4533+
.. versionadded:: 2.1.0
45264534
**kwargs
45274535
Additional keyword arguments passed to func.
45284536
@@ -4611,7 +4619,12 @@ def apply(
46114619
dtype: float64
46124620
"""
46134621
return SeriesApply(
4614-
self, func, convert_dtype=convert_dtype, args=args, kwargs=kwargs
4622+
self,
4623+
func,
4624+
convert_dtype=convert_dtype,
4625+
by_row=by_row,
4626+
args=args,
4627+
kwargs=kwargs,
46154628
).apply()
46164629

46174630
def _reindex_indexer(

0 commit comments

Comments
 (0)