Skip to content

Commit e932ec9

Browse files
authored
BUG: .transform(...) with "first" and "last" fail when axis=1 (#46074)
1 parent 7dea5ae commit e932ec9

File tree

4 files changed

+15
-41
lines changed

4 files changed

+15
-41
lines changed

doc/source/whatsnew/v1.5.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ Groupby/resample/rolling
382382
- Bug in :meth:`DataFrame.resample` ignoring ``closed="right"`` on :class:`TimedeltaIndex` (:issue:`45414`)
383383
- Bug in :meth:`.DataFrameGroupBy.transform` failing when ``func="size"`` and the input DataFrame has multiple columns (:issue:`27469`)
384384
- Bug in :meth:`.DataFrameGroupBy.size` and :meth:`.DataFrameGroupBy.transform` with ``func="size"`` produced incorrect results when ``axis=1`` (:issue:`45715`)
385+
- Bug in :meth:`.DataFrameGroupBy.transform` failing when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`)
385386

386387
Reshaping
387388
^^^^^^^^^

pandas/core/groupby/generic.py

-9
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,6 @@ def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
472472
result.name = self.obj.name
473473
return result
474474

475-
def _can_use_transform_fast(self, func: str, result) -> bool:
476-
return True
477-
478475
def filter(self, func, dropna: bool = True, *args, **kwargs):
479476
"""
480477
Return a copy of a Series excluding elements from groups that
@@ -1184,12 +1181,6 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
11841181
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
11851182
)
11861183

1187-
def _can_use_transform_fast(self, func: str, result) -> bool:
1188-
return func == "size" or (
1189-
isinstance(result, DataFrame)
1190-
and result.columns.equals(self._obj_with_exclusions.columns)
1191-
)
1192-
11931184
def _define_paths(self, func, *args, **kwargs):
11941185
if isinstance(func, str):
11951186
fast_path = lambda group: getattr(group, func)(*args, **kwargs)

pandas/core/groupby/groupby.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -1650,11 +1650,7 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
16501650
with com.temp_setattr(self, "observed", True):
16511651
result = getattr(self, func)(*args, **kwargs)
16521652

1653-
if self._can_use_transform_fast(func, result):
1654-
return self._wrap_transform_fast_result(result)
1655-
1656-
# only reached for DataFrameGroupBy
1657-
return self._transform_general(func, *args, **kwargs)
1653+
return self._wrap_transform_fast_result(result)
16581654

16591655
@final
16601656
def _wrap_transform_fast_result(self, result: NDFrameT) -> NDFrameT:

pandas/tests/groupby/transform/test_transform.py

+13-27
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@
2121
)
2222
import pandas._testing as tm
2323
from pandas.core.groupby.base import maybe_normalize_deprecated_kernels
24-
from pandas.core.groupby.generic import (
25-
DataFrameGroupBy,
26-
SeriesGroupBy,
27-
)
24+
from pandas.core.groupby.generic import DataFrameGroupBy
2825

2926

3027
def assert_fp_equal(a, b):
@@ -195,10 +192,8 @@ def test_transform_axis_1_reducer(request, reduction_func):
195192
# GH#45715
196193
if reduction_func in (
197194
"corrwith",
198-
"first",
199195
"idxmax",
200196
"idxmin",
201-
"last",
202197
"ngroup",
203198
"nth",
204199
):
@@ -418,45 +413,36 @@ def test_transform_select_columns(df):
418413
tm.assert_frame_equal(result, expected)
419414

420415

421-
@pytest.mark.parametrize("duplicates", [True, False])
422-
def test_transform_exclude_nuisance(df, duplicates):
416+
def test_transform_exclude_nuisance(df):
423417
# case that goes through _transform_item_by_item
424418

425-
if duplicates:
426-
# make sure we work with duplicate columns GH#41427
427-
df.columns = ["A", "C", "C", "D"]
419+
df.columns = ["A", "B", "B", "D"]
428420

429421
# this also tests orderings in transform between
430422
# series/frame to make sure it's consistent
431423
expected = {}
432424
grouped = df.groupby("A")
433425

434-
gbc = grouped["C"]
435-
warn = FutureWarning if duplicates else None
436-
with tm.assert_produces_warning(warn, match="Dropping invalid columns"):
437-
expected["C"] = gbc.transform(np.mean)
438-
if duplicates:
439-
# squeeze 1-column DataFrame down to Series
440-
expected["C"] = expected["C"]["C"]
426+
gbc = grouped["B"]
427+
with tm.assert_produces_warning(FutureWarning, match="Dropping invalid columns"):
428+
expected["B"] = gbc.transform(lambda x: np.mean(x))
429+
# squeeze 1-column DataFrame down to Series
430+
expected["B"] = expected["B"]["B"]
441431

442-
assert isinstance(gbc.obj, DataFrame)
443-
assert isinstance(gbc, DataFrameGroupBy)
444-
else:
445-
assert isinstance(gbc, SeriesGroupBy)
446-
assert isinstance(gbc.obj, Series)
432+
assert isinstance(gbc.obj, DataFrame)
433+
assert isinstance(gbc, DataFrameGroupBy)
447434

448435
expected["D"] = grouped["D"].transform(np.mean)
449436
expected = DataFrame(expected)
450437
with tm.assert_produces_warning(FutureWarning, match="Dropping invalid columns"):
451-
result = df.groupby("A").transform(np.mean)
438+
result = df.groupby("A").transform(lambda x: np.mean(x))
452439

453440
tm.assert_frame_equal(result, expected)
454441

455442

456443
def test_transform_function_aliases(df):
457-
with tm.assert_produces_warning(FutureWarning, match="Dropping invalid columns"):
458-
result = df.groupby("A").transform("mean")
459-
expected = df.groupby("A").transform(np.mean)
444+
result = df.groupby("A").transform("mean")
445+
expected = df.groupby("A").transform(np.mean)
460446
tm.assert_frame_equal(result, expected)
461447

462448
result = df.groupby("A")["C"].transform("mean")

0 commit comments

Comments
 (0)