diff --git a/doc/source/whatsnew/v1.2.3.rst b/doc/source/whatsnew/v1.2.3.rst index f72ee78bf243a..99e997189d7b8 100644 --- a/doc/source/whatsnew/v1.2.3.rst +++ b/doc/source/whatsnew/v1.2.3.rst @@ -24,6 +24,8 @@ Fixed regressions Passing ``ascending=None`` is still considered invalid, and the new error message suggests a proper usage (``ascending`` must be a boolean or a list-like boolean). +- Fixed regression in :meth:`DataFrame.transform` and :meth:`Series.transform` giving incorrect column labels when passed a dictionary with a mix of list and non-list values (:issue:`40018`) +- .. --------------------------------------------------------------------------- diff --git a/pandas/core/apply.py b/pandas/core/apply.py index db4203e5158ef..970629f4abfe9 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -280,7 +280,7 @@ def transform_dict_like(self, func): if len(func) == 0: raise ValueError("No transform functions were provided") - self.validate_dictlike_arg("transform", obj, func) + func = self.normalize_dictlike_arg("transform", obj, func) results: Dict[Hashable, FrameOrSeriesUnion] = {} for name, how in func.items(): @@ -421,32 +421,17 @@ def agg_dict_like(self, _axis: int) -> FrameOrSeriesUnion: ------- Result of aggregation. """ + from pandas.core.reshape.concat import concat + obj = self.obj arg = cast(AggFuncTypeDict, self.f) - is_aggregator = lambda x: isinstance(x, (list, tuple, dict)) - if _axis != 0: # pragma: no cover raise ValueError("Can only pass dict with axis=0") selected_obj = obj._selected_obj - self.validate_dictlike_arg("agg", selected_obj, arg) - - # if we have a dict of any non-scalars - # eg. 
{'A' : ['mean']}, normalize all to - # be list-likes - # Cannot use arg.values() because arg may be a Series - if any(is_aggregator(x) for _, x in arg.items()): - new_arg: AggFuncTypeDict = {} - for k, v in arg.items(): - if not isinstance(v, (tuple, list, dict)): - new_arg[k] = [v] - else: - new_arg[k] = v - arg = new_arg - - from pandas.core.reshape.concat import concat + arg = self.normalize_dictlike_arg("agg", selected_obj, arg) if selected_obj.ndim == 1: # key only used for output @@ -540,14 +525,15 @@ def maybe_apply_multiple(self) -> Optional[FrameOrSeriesUnion]: return None return self.obj.aggregate(self.f, self.axis, *self.args, **self.kwargs) - def validate_dictlike_arg( + def normalize_dictlike_arg( self, how: str, obj: FrameOrSeriesUnion, func: AggFuncTypeDict - ) -> None: + ) -> AggFuncTypeDict: """ - Raise if dict-like argument is invalid. + Handler for dict-like argument. Ensures that necessary columns exist if obj is a DataFrame, and - that a nested renamer is not passed. + that a nested renamer is not passed. Also normalizes to all lists + when values consist of a mix of lists and non-lists. """ assert how in ("apply", "agg", "transform") @@ -567,6 +553,23 @@ def validate_dictlike_arg( cols_sorted = list(safe_sort(list(cols))) raise KeyError(f"Column(s) {cols_sorted} do not exist") + is_aggregator = lambda x: isinstance(x, (list, tuple, dict)) + + # if we have a dict of any non-scalars + # eg. 
{'A' : ['mean']}, normalize all to + # be list-likes + # Cannot use func.values() because func may be a Series + if any(is_aggregator(x) for _, x in func.items()): + new_func: AggFuncTypeDict = {} + for k, v in func.items(): + if not is_aggregator(v): + # mypy can't realize v is not a list here + new_func[k] = [v] # type:ignore[list-item] + else: + new_func[k] = v + func = new_func + return func + class FrameApply(Apply): obj: DataFrame diff --git a/pandas/tests/apply/test_frame_transform.py b/pandas/tests/apply/test_frame_transform.py index 1888ddd8ec4aa..47bc69656a597 100644 --- a/pandas/tests/apply/test_frame_transform.py +++ b/pandas/tests/apply/test_frame_transform.py @@ -103,6 +103,17 @@ def test_transform_dictlike(axis, float_frame, box): tm.assert_frame_equal(result, expected) +def test_transform_dictlike_mixed(): + # GH 40018 - mix of lists and non-lists in values of a dictionary + df = DataFrame({"a": [1, 2], "b": [1, 4], "c": [1, 4]}) + result = df.transform({"b": ["sqrt", "abs"], "c": "sqrt"}) + expected = DataFrame( + [[1.0, 1, 1.0], [2.0, 4, 2.0]], + columns=MultiIndex([("b", "c"), ("sqrt", "abs")], [(0, 0, 1), (0, 1, 0)]), + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( "ops", [ diff --git a/pandas/tests/apply/test_series_transform.py b/pandas/tests/apply/test_series_transform.py index e67ea4f14e4ac..24d619cb2bbb1 100644 --- a/pandas/tests/apply/test_series_transform.py +++ b/pandas/tests/apply/test_series_transform.py @@ -2,6 +2,8 @@ import pytest from pandas import ( + DataFrame, + MultiIndex, Series, concat, ) @@ -55,6 +57,17 @@ def test_transform_dictlike(string_series, box): tm.assert_frame_equal(result, expected) +def test_transform_dictlike_mixed(): + # GH 40018 - mix of lists and non-lists in values of a dictionary + df = Series([1, 4]) + result = df.transform({"b": ["sqrt", "abs"], "c": "sqrt"}) + expected = DataFrame( + [[1.0, 1, 1.0], [2.0, 4, 2.0]], + columns=MultiIndex([("b", "c"), ("sqrt", "abs")], [(0, 0, 1), 
(0, 1, 0)]), + ) + tm.assert_frame_equal(result, expected) + + def test_transform_wont_agg(string_series): # GH 35964 # we are trying to transform with an aggregator