Skip to content

Commit ebd9906

Browse files
authored
CLN: aggregation.transform (#36478)
1 parent 98af5d7 commit ebd9906

File tree

6 files changed

+112
-26
lines changed

6 files changed

+112
-26
lines changed

doc/source/whatsnew/v1.2.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ Reshaping
424424
- Bug in :func:`union_indexes` where input index names are not preserved in some cases. Affects :func:`concat` and :class:`DataFrame` constructor (:issue:`13475`)
425425
- Bug in func :meth:`crosstab` when using multiple columns with ``margins=True`` and ``normalize=True`` (:issue:`35144`)
426426
- Bug in :meth:`DataFrame.agg` with ``func={'name':<FUNC>}`` incorrectly raising ``TypeError`` when ``DataFrame.columns==['Name']`` (:issue:`36212`)
427+
- Bug in :meth:`Series.transform` would give incorrect results or raise when the argument ``func`` was dictionary (:issue:`35811`)
427428
-
428429

429430
Sparse
@@ -446,7 +447,6 @@ Other
446447
- Bug in :meth:`DataFrame.replace` and :meth:`Series.replace` incorrectly raising ``AssertionError`` instead of ``ValueError`` when invalid parameter combinations are passed (:issue:`36045`)
447448
- Bug in :meth:`DataFrame.replace` and :meth:`Series.replace` with numeric values and string ``to_replace`` (:issue:`34789`)
448449
- Fixed metadata propagation in the :class:`Series.dt` accessor (:issue:`28283`)
449-
- Bug in :meth:`Series.transform` would give incorrect results or raise when the argument ``func`` was dictionary (:issue:`35811`)
450450
- Bug in :meth:`Index.union` behaving differently depending on whether operand is a :class:`Index` or other list-like (:issue:`36384`)
451451

452452
.. ---------------------------------------------------------------------------

pandas/core/aggregation.py

+38-11
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,17 @@
1717
Sequence,
1818
Tuple,
1919
Union,
20+
cast,
2021
)
2122

22-
from pandas._typing import AggFuncType, Axis, FrameOrSeries, Label
23+
from pandas._typing import (
24+
AggFuncType,
25+
AggFuncTypeBase,
26+
Axis,
27+
FrameOrSeries,
28+
FrameOrSeriesUnion,
29+
Label,
30+
)
2331

2432
from pandas.core.dtypes.common import is_dict_like, is_list_like
2533
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
@@ -391,7 +399,7 @@ def validate_func_kwargs(
391399

392400
def transform(
393401
obj: FrameOrSeries, func: AggFuncType, axis: Axis, *args, **kwargs
394-
) -> FrameOrSeries:
402+
) -> FrameOrSeriesUnion:
395403
"""
396404
Transform a DataFrame or Series
397405
@@ -424,16 +432,20 @@ def transform(
424432
assert not is_series
425433
return transform(obj.T, func, 0, *args, **kwargs).T
426434

427-
if isinstance(func, list):
435+
if is_list_like(func) and not is_dict_like(func):
436+
func = cast(List[AggFuncTypeBase], func)
437+
# Convert func equivalent dict
428438
if is_series:
429439
func = {com.get_callable_name(v) or v: v for v in func}
430440
else:
431441
func = {col: func for col in obj}
432442

433-
if isinstance(func, dict):
443+
if is_dict_like(func):
444+
func = cast(Dict[Label, Union[AggFuncTypeBase, List[AggFuncTypeBase]]], func)
434445
return transform_dict_like(obj, func, *args, **kwargs)
435446

436447
# func is either str or callable
448+
func = cast(AggFuncTypeBase, func)
437449
try:
438450
result = transform_str_or_callable(obj, func, *args, **kwargs)
439451
except Exception:
@@ -451,37 +463,52 @@ def transform(
451463
return result
452464

453465

454-
def transform_dict_like(obj, func, *args, **kwargs):
466+
def transform_dict_like(
467+
obj: FrameOrSeries,
468+
func: Dict[Label, Union[AggFuncTypeBase, List[AggFuncTypeBase]]],
469+
*args,
470+
**kwargs,
471+
):
455472
"""
456473
Compute transform in the case of a dict-like func
457474
"""
458475
from pandas.core.reshape.concat import concat
459476

477+
if len(func) == 0:
478+
raise ValueError("No transform functions were provided")
479+
460480
if obj.ndim != 1:
481+
# Check for missing columns on a frame
461482
cols = sorted(set(func.keys()) - set(obj.columns))
462483
if len(cols) > 0:
463484
raise SpecificationError(f"Column(s) {cols} do not exist")
464485

465-
if any(isinstance(v, dict) for v in func.values()):
486+
# Can't use func.values(); wouldn't work for a Series
487+
if any(is_dict_like(v) for _, v in func.items()):
466488
# GH 15931 - deprecation of renaming keys
467489
raise SpecificationError("nested renamer is not supported")
468490

469-
results = {}
491+
results: Dict[Label, FrameOrSeriesUnion] = {}
470492
for name, how in func.items():
471493
colg = obj._gotitem(name, ndim=1)
472494
try:
473495
results[name] = transform(colg, how, 0, *args, **kwargs)
474-
except Exception as e:
475-
if str(e) == "Function did not transform":
476-
raise e
496+
except Exception as err:
497+
if (
498+
str(err) == "Function did not transform"
499+
or str(err) == "No transform functions were provided"
500+
):
501+
raise err
477502

478503
# combine results
479504
if len(results) == 0:
480505
raise ValueError("Transform function failed")
481506
return concat(results, axis=1)
482507

483508

484-
def transform_str_or_callable(obj, func, *args, **kwargs):
509+
def transform_str_or_callable(
510+
obj: FrameOrSeries, func: AggFuncTypeBase, *args, **kwargs
511+
) -> FrameOrSeriesUnion:
485512
"""
486513
Compute transform in the case of a string or callable func
487514
"""

pandas/core/frame.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7270,7 +7270,7 @@ def diff(self, periods: int = 1, axis: Axis = 0) -> DataFrame:
72707270

72717271
def _gotitem(
72727272
self,
7273-
key: Union[str, List[str]],
7273+
key: Union[Label, List[Label]],
72747274
ndim: int,
72757275
subset: Optional[FrameOrSeriesUnion] = None,
72767276
) -> FrameOrSeriesUnion:

pandas/core/shared_docs.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -265,16 +265,17 @@
265265
266266
Parameters
267267
----------
268-
func : function, str, list or dict
268+
func : function, str, list-like or dict-like
269269
Function to use for transforming the data. If a function, must either
270-
work when passed a {klass} or when passed to {klass}.apply.
270+
work when passed a {klass} or when passed to {klass}.apply. If func
271+
is both list-like and dict-like, dict-like behavior takes precedence.
271272
272273
Accepted combinations are:
273274
274275
- function
275276
- string function name
276-
- list of functions and/or function names, e.g. ``[np.exp, 'sqrt']``
277-
- dict of axis labels -> functions, function names or list of such.
277+
- list-like of functions and/or function names, e.g. ``[np.exp, 'sqrt']``
278+
- dict-like of axis labels -> functions, function names or list-like of such.
278279
{axis}
279280
*args
280281
Positional arguments to pass to `func`.

pandas/tests/frame/apply/test_frame_transform.py

+34-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pytest
66

7-
from pandas import DataFrame, MultiIndex
7+
from pandas import DataFrame, MultiIndex, Series
88
import pandas._testing as tm
99
from pandas.core.base import SpecificationError
1010
from pandas.core.groupby.base import transformation_kernels
@@ -41,9 +41,15 @@ def test_transform_groupby_kernel(axis, float_frame, op):
4141

4242

4343
@pytest.mark.parametrize(
44-
"ops, names", [([np.sqrt], ["sqrt"]), ([np.abs, np.sqrt], ["absolute", "sqrt"])]
44+
"ops, names",
45+
[
46+
([np.sqrt], ["sqrt"]),
47+
([np.abs, np.sqrt], ["absolute", "sqrt"]),
48+
(np.array([np.sqrt]), ["sqrt"]),
49+
(np.array([np.abs, np.sqrt]), ["absolute", "sqrt"]),
50+
],
4551
)
46-
def test_transform_list(axis, float_frame, ops, names):
52+
def test_transform_listlike(axis, float_frame, ops, names):
4753
# GH 35964
4854
other_axis = 1 if axis in {0, "index"} else 0
4955
with np.errstate(all="ignore"):
@@ -56,18 +62,41 @@ def test_transform_list(axis, float_frame, ops, names):
5662
tm.assert_frame_equal(result, expected)
5763

5864

59-
def test_transform_dict(axis, float_frame):
65+
@pytest.mark.parametrize("ops", [[], np.array([])])
66+
def test_transform_empty_listlike(float_frame, ops):
67+
with pytest.raises(ValueError, match="No transform functions were provided"):
68+
float_frame.transform(ops)
69+
70+
71+
@pytest.mark.parametrize("box", [dict, Series])
72+
def test_transform_dictlike(axis, float_frame, box):
6073
# GH 35964
6174
if axis == 0 or axis == "index":
6275
e = float_frame.columns[0]
6376
expected = float_frame[[e]].transform(np.abs)
6477
else:
6578
e = float_frame.index[0]
6679
expected = float_frame.iloc[[0]].transform(np.abs)
67-
result = float_frame.transform({e: np.abs}, axis=axis)
80+
result = float_frame.transform(box({e: np.abs}), axis=axis)
6881
tm.assert_frame_equal(result, expected)
6982

7083

84+
@pytest.mark.parametrize(
85+
"ops",
86+
[
87+
{},
88+
{"A": []},
89+
{"A": [], "B": "cumsum"},
90+
{"A": "cumsum", "B": []},
91+
{"A": [], "B": ["cumsum"]},
92+
{"A": ["cumsum"], "B": []},
93+
],
94+
)
95+
def test_transform_empty_dictlike(float_frame, ops):
96+
with pytest.raises(ValueError, match="No transform functions were provided"):
97+
float_frame.transform(ops)
98+
99+
71100
@pytest.mark.parametrize("use_apply", [True, False])
72101
def test_transform_udf(axis, float_frame, use_apply):
73102
# GH 35964

pandas/tests/series/apply/test_series_transform.py

+33-4
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,15 @@ def test_transform_groupby_kernel(string_series, op):
3434

3535

3636
@pytest.mark.parametrize(
37-
"ops, names", [([np.sqrt], ["sqrt"]), ([np.abs, np.sqrt], ["absolute", "sqrt"])]
37+
"ops, names",
38+
[
39+
([np.sqrt], ["sqrt"]),
40+
([np.abs, np.sqrt], ["absolute", "sqrt"]),
41+
(np.array([np.sqrt]), ["sqrt"]),
42+
(np.array([np.abs, np.sqrt]), ["absolute", "sqrt"]),
43+
],
3844
)
39-
def test_transform_list(string_series, ops, names):
45+
def test_transform_listlike(string_series, ops, names):
4046
# GH 35964
4147
with np.errstate(all="ignore"):
4248
expected = concat([op(string_series) for op in ops], axis=1)
@@ -45,15 +51,38 @@ def test_transform_list(string_series, ops, names):
4551
tm.assert_frame_equal(result, expected)
4652

4753

48-
def test_transform_dict(string_series):
54+
@pytest.mark.parametrize("ops", [[], np.array([])])
55+
def test_transform_empty_listlike(string_series, ops):
56+
with pytest.raises(ValueError, match="No transform functions were provided"):
57+
string_series.transform(ops)
58+
59+
60+
@pytest.mark.parametrize("box", [dict, Series])
61+
def test_transform_dictlike(string_series, box):
4962
# GH 35964
5063
with np.errstate(all="ignore"):
5164
expected = concat([np.sqrt(string_series), np.abs(string_series)], axis=1)
5265
expected.columns = ["foo", "bar"]
53-
result = string_series.transform({"foo": np.sqrt, "bar": np.abs})
66+
result = string_series.transform(box({"foo": np.sqrt, "bar": np.abs}))
5467
tm.assert_frame_equal(result, expected)
5568

5669

70+
@pytest.mark.parametrize(
71+
"ops",
72+
[
73+
{},
74+
{"A": []},
75+
{"A": [], "B": ["cumsum"]},
76+
{"A": ["cumsum"], "B": []},
77+
{"A": [], "B": "cumsum"},
78+
{"A": "cumsum", "B": []},
79+
],
80+
)
81+
def test_transform_empty_dictlike(string_series, ops):
82+
with pytest.raises(ValueError, match="No transform functions were provided"):
83+
string_series.transform(ops)
84+
85+
5786
def test_transform_udf(axis, string_series):
5887
# GH 35964
5988
# via apply

0 commit comments

Comments
 (0)