Skip to content

Commit ab5b38d

Browse files
authored
BUG/CLN: Decouple Series/DataFrame.transform (#35964)
1 parent bed9656 commit ab5b38d

File tree

10 files changed

+507
-169
lines changed

10 files changed

+507
-169
lines changed

doc/source/whatsnew/v1.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ Other
344344
^^^^^
345345
- Bug in :meth:`DataFrame.replace` and :meth:`Series.replace` incorrectly raising ``AssertionError`` instead of ``ValueError`` when invalid parameter combinations are passed (:issue:`36045`)
346346
- Bug in :meth:`DataFrame.replace` and :meth:`Series.replace` with numeric values and string ``to_replace`` (:issue:`34789`)
347+
- Bug in :meth:`Series.transform` would give incorrect results or raise when the argument ``func`` was dictionary (:issue:`35811`)
347348
-
348349

349350
.. ---------------------------------------------------------------------------

pandas/core/aggregation.py

+97-1
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
Union,
1919
)
2020

21-
from pandas._typing import AggFuncType, FrameOrSeries, Label
21+
from pandas._typing import AggFuncType, Axis, FrameOrSeries, Label
2222

2323
from pandas.core.dtypes.common import is_dict_like, is_list_like
24+
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
2425

2526
from pandas.core.base import SpecificationError
2627
import pandas.core.common as com
@@ -384,3 +385,98 @@ def validate_func_kwargs(
384385
if not columns:
385386
raise TypeError(no_arg_message)
386387
return columns, func
388+
389+
390+
def transform(
391+
obj: FrameOrSeries, func: AggFuncType, axis: Axis, *args, **kwargs,
392+
) -> FrameOrSeries:
393+
"""
394+
Transform a DataFrame or Series
395+
396+
Parameters
397+
----------
398+
obj : DataFrame or Series
399+
Object to compute the transform on.
400+
func : string, function, list, or dictionary
401+
Function(s) to compute the transform with.
402+
axis : {0 or 'index', 1 or 'columns'}
403+
Axis along which the function is applied:
404+
405+
* 0 or 'index': apply function to each column.
406+
* 1 or 'columns': apply function to each row.
407+
408+
Returns
409+
-------
410+
DataFrame or Series
411+
Result of applying ``func`` along the given axis of the
412+
Series or DataFrame.
413+
414+
Raises
415+
------
416+
ValueError
417+
If the transform function fails or does not transform.
418+
"""
419+
from pandas.core.reshape.concat import concat
420+
421+
is_series = obj.ndim == 1
422+
423+
if obj._get_axis_number(axis) == 1:
424+
assert not is_series
425+
return transform(obj.T, func, 0, *args, **kwargs).T
426+
427+
if isinstance(func, list):
428+
if is_series:
429+
func = {com.get_callable_name(v) or v: v for v in func}
430+
else:
431+
func = {col: func for col in obj}
432+
433+
if isinstance(func, dict):
434+
if not is_series:
435+
cols = sorted(set(func.keys()) - set(obj.columns))
436+
if len(cols) > 0:
437+
raise SpecificationError(f"Column(s) {cols} do not exist")
438+
439+
if any(isinstance(v, dict) for v in func.values()):
440+
# GH 15931 - deprecation of renaming keys
441+
raise SpecificationError("nested renamer is not supported")
442+
443+
results = {}
444+
for name, how in func.items():
445+
colg = obj._gotitem(name, ndim=1)
446+
try:
447+
results[name] = transform(colg, how, 0, *args, **kwargs)
448+
except Exception as e:
449+
if str(e) == "Function did not transform":
450+
raise e
451+
452+
# combine results
453+
if len(results) == 0:
454+
raise ValueError("Transform function failed")
455+
return concat(results, axis=1)
456+
457+
# func is either str or callable
458+
try:
459+
if isinstance(func, str):
460+
result = obj._try_aggregate_string_function(func, *args, **kwargs)
461+
else:
462+
f = obj._get_cython_func(func)
463+
if f and not args and not kwargs:
464+
result = getattr(obj, f)()
465+
else:
466+
try:
467+
result = obj.apply(func, args=args, **kwargs)
468+
except Exception:
469+
result = func(obj, *args, **kwargs)
470+
except Exception:
471+
raise ValueError("Transform function failed")
472+
473+
# Functions that transform may return empty Series/DataFrame
474+
# when the dtype is not appropriate
475+
if isinstance(result, (ABCSeries, ABCDataFrame)) and result.empty:
476+
raise ValueError("Transform function failed")
477+
if not isinstance(result, (ABCSeries, ABCDataFrame)) or not result.index.equals(
478+
obj.index
479+
):
480+
raise ValueError("Function did not transform")
481+
482+
return result

pandas/core/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import builtins
66
import textwrap
7-
from typing import Any, Dict, FrozenSet, List, Optional, Union
7+
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Union
88

99
import numpy as np
1010

@@ -560,7 +560,7 @@ def _aggregate_multiple_funcs(self, arg, _axis):
560560
) from err
561561
return result
562562

563-
def _get_cython_func(self, arg: str) -> Optional[str]:
563+
def _get_cython_func(self, arg: Callable) -> Optional[str]:
564564
"""
565565
if we define an internal function for this argument, return it
566566
"""

pandas/core/frame.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from pandas._libs import algos as libalgos, lib, properties
4646
from pandas._libs.lib import no_default
4747
from pandas._typing import (
48+
AggFuncType,
4849
ArrayLike,
4950
Axes,
5051
Axis,
@@ -116,7 +117,7 @@
116117

117118
from pandas.core import algorithms, common as com, nanops, ops
118119
from pandas.core.accessor import CachedAccessor
119-
from pandas.core.aggregation import reconstruct_func, relabel_result
120+
from pandas.core.aggregation import reconstruct_func, relabel_result, transform
120121
from pandas.core.arrays import Categorical, ExtensionArray
121122
from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin as DatetimeLikeArray
122123
from pandas.core.arrays.sparse import SparseFrameAccessor
@@ -7462,15 +7463,16 @@ def _aggregate(self, arg, axis=0, *args, **kwargs):
74627463
agg = aggregate
74637464

74647465
@doc(
7465-
NDFrame.transform,
7466+
_shared_docs["transform"],
74667467
klass=_shared_doc_kwargs["klass"],
74677468
axis=_shared_doc_kwargs["axis"],
74687469
)
7469-
def transform(self, func, axis=0, *args, **kwargs) -> DataFrame:
7470-
axis = self._get_axis_number(axis)
7471-
if axis == 1:
7472-
return self.T.transform(func, *args, **kwargs).T
7473-
return super().transform(func, *args, **kwargs)
7470+
def transform(
7471+
self, func: AggFuncType, axis: Axis = 0, *args, **kwargs
7472+
) -> DataFrame:
7473+
result = transform(self, func, axis, *args, **kwargs)
7474+
assert isinstance(result, DataFrame)
7475+
return result
74747476

74757477
def apply(self, func, axis=0, raw=False, result_type=None, args=(), **kwds):
74767478
"""

pandas/core/generic.py

-74
Original file line numberDiff line numberDiff line change
@@ -10648,80 +10648,6 @@ def ewm(
1064810648
times=times,
1064910649
)
1065010650

10651-
@doc(klass=_shared_doc_kwargs["klass"], axis="")
10652-
def transform(self, func, *args, **kwargs):
10653-
"""
10654-
Call ``func`` on self producing a {klass} with transformed values.
10655-
10656-
Produced {klass} will have same axis length as self.
10657-
10658-
Parameters
10659-
----------
10660-
func : function, str, list or dict
10661-
Function to use for transforming the data. If a function, must either
10662-
work when passed a {klass} or when passed to {klass}.apply.
10663-
10664-
Accepted combinations are:
10665-
10666-
- function
10667-
- string function name
10668-
- list of functions and/or function names, e.g. ``[np.exp, 'sqrt']``
10669-
- dict of axis labels -> functions, function names or list of such.
10670-
{axis}
10671-
*args
10672-
Positional arguments to pass to `func`.
10673-
**kwargs
10674-
Keyword arguments to pass to `func`.
10675-
10676-
Returns
10677-
-------
10678-
{klass}
10679-
A {klass} that must have the same length as self.
10680-
10681-
Raises
10682-
------
10683-
ValueError : If the returned {klass} has a different length than self.
10684-
10685-
See Also
10686-
--------
10687-
{klass}.agg : Only perform aggregating type operations.
10688-
{klass}.apply : Invoke function on a {klass}.
10689-
10690-
Examples
10691-
--------
10692-
>>> df = pd.DataFrame({{'A': range(3), 'B': range(1, 4)}})
10693-
>>> df
10694-
A B
10695-
0 0 1
10696-
1 1 2
10697-
2 2 3
10698-
>>> df.transform(lambda x: x + 1)
10699-
A B
10700-
0 1 2
10701-
1 2 3
10702-
2 3 4
10703-
10704-
Even though the resulting {klass} must have the same length as the
10705-
input {klass}, it is possible to provide several input functions:
10706-
10707-
>>> s = pd.Series(range(3))
10708-
>>> s
10709-
0 0
10710-
1 1
10711-
2 2
10712-
dtype: int64
10713-
>>> s.transform([np.sqrt, np.exp])
10714-
sqrt exp
10715-
0 0.000000 1.000000
10716-
1 1.000000 2.718282
10717-
2 1.414214 7.389056
10718-
"""
10719-
result = self.agg(func, *args, **kwargs)
10720-
if is_scalar(result) or len(result) != len(self):
10721-
raise ValueError("transforms cannot produce aggregated results")
10722-
10723-
return result
10724-
1072510651
# ----------------------------------------------------------------------
1072610652
# Misc methods
1072710653

pandas/core/series.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pandas._libs import lib, properties, reshape, tslibs
2626
from pandas._libs.lib import no_default
2727
from pandas._typing import (
28+
AggFuncType,
2829
ArrayLike,
2930
Axis,
3031
DtypeObj,
@@ -89,6 +90,7 @@
8990
from pandas.core.indexes.timedeltas import TimedeltaIndex
9091
from pandas.core.indexing import check_bool_indexer
9192
from pandas.core.internals import SingleBlockManager
93+
from pandas.core.shared_docs import _shared_docs
9294
from pandas.core.sorting import ensure_key_mapped
9395
from pandas.core.strings import StringMethods
9496
from pandas.core.tools.datetimes import to_datetime
@@ -4081,14 +4083,16 @@ def aggregate(self, func=None, axis=0, *args, **kwargs):
40814083
agg = aggregate
40824084

40834085
@doc(
4084-
NDFrame.transform,
4086+
_shared_docs["transform"],
40854087
klass=_shared_doc_kwargs["klass"],
40864088
axis=_shared_doc_kwargs["axis"],
40874089
)
4088-
def transform(self, func, axis=0, *args, **kwargs):
4089-
# Validate the axis parameter
4090-
self._get_axis_number(axis)
4091-
return super().transform(func, *args, **kwargs)
4090+
def transform(
4091+
self, func: AggFuncType, axis: Axis = 0, *args, **kwargs
4092+
) -> FrameOrSeriesUnion:
4093+
from pandas.core.aggregation import transform
4094+
4095+
return transform(self, func, axis, *args, **kwargs)
40924096

40934097
def apply(self, func, convert_dtype=True, args=(), **kwds):
40944098
"""

pandas/core/shared_docs.py

+69
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,72 @@
257257
1 b B E 3
258258
2 c B E 5
259259
"""
260+
261+
_shared_docs[
262+
"transform"
263+
] = """\
264+
Call ``func`` on self producing a {klass} with transformed values.
265+
266+
Produced {klass} will have same axis length as self.
267+
268+
Parameters
269+
----------
270+
func : function, str, list or dict
271+
Function to use for transforming the data. If a function, must either
272+
work when passed a {klass} or when passed to {klass}.apply.
273+
274+
Accepted combinations are:
275+
276+
- function
277+
- string function name
278+
- list of functions and/or function names, e.g. ``[np.exp, 'sqrt']``
279+
- dict of axis labels -> functions, function names or list of such.
280+
{axis}
281+
*args
282+
Positional arguments to pass to `func`.
283+
**kwargs
284+
Keyword arguments to pass to `func`.
285+
286+
Returns
287+
-------
288+
{klass}
289+
A {klass} that must have the same length as self.
290+
291+
Raises
292+
------
293+
ValueError : If the returned {klass} has a different length than self.
294+
295+
See Also
296+
--------
297+
{klass}.agg : Only perform aggregating type operations.
298+
{klass}.apply : Invoke function on a {klass}.
299+
300+
Examples
301+
--------
302+
>>> df = pd.DataFrame({{'A': range(3), 'B': range(1, 4)}})
303+
>>> df
304+
A B
305+
0 0 1
306+
1 1 2
307+
2 2 3
308+
>>> df.transform(lambda x: x + 1)
309+
A B
310+
0 1 2
311+
1 2 3
312+
2 3 4
313+
314+
Even though the resulting {klass} must have the same length as the
315+
input {klass}, it is possible to provide several input functions:
316+
317+
>>> s = pd.Series(range(3))
318+
>>> s
319+
0 0
320+
1 1
321+
2 2
322+
dtype: int64
323+
>>> s.transform([np.sqrt, np.exp])
324+
sqrt exp
325+
0 0.000000 1.000000
326+
1 1.000000 2.718282
327+
2 1.414214 7.389056
328+
"""

0 commit comments

Comments
 (0)