Skip to content

Commit 138d575

Browse files
authored
BUG: DataFrameGroupBy.transform with axis=1 fails (pandas-dev#36308) (pandas-dev#36350)
1 parent 4cfa97a commit 138d575

File tree

5 files changed

+32
-11
lines changed

5 files changed

+32
-11
lines changed

doc/source/whatsnew/v1.2.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ Groupby/resample/rolling
601601
- Bug in :meth:`Rolling.median` and :meth:`Rolling.quantile` returned wrong values for :class:`BaseIndexer` subclasses with non-monotonic starting or ending points for windows (:issue:`37153`)
602602
- Bug in :meth:`DataFrame.groupby` dropped ``nan`` groups from result with ``dropna=False`` when grouping over a single column (:issue:`35646`, :issue:`35542`)
603603
- Bug in :meth:`DataFrameGroupBy.head`, :meth:`DataFrameGroupBy.tail`, :meth:`SeriesGroupBy.head`, and :meth:`SeriesGroupBy.tail` would raise when used with ``axis=1`` (:issue:`9772`)
604-
604+
- Bug in :meth:`DataFrameGroupBy.transform` would raise when used with ``axis=1`` and a transformation kernel (e.g. "shift") (:issue:`36308`)
605605

606606
Reshaping
607607
^^^^^^^^^

pandas/core/groupby/generic.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -1675,11 +1675,16 @@ def _wrap_transformed_output(
16751675
DataFrame
16761676
"""
16771677
indexed_output = {key.position: val for key, val in output.items()}
1678-
columns = Index(key.label for key in output)
1679-
columns.name = self.obj.columns.name
1680-
16811678
result = self.obj._constructor(indexed_output)
1682-
result.columns = columns
1679+
1680+
if self.axis == 1:
1681+
result = result.T
1682+
result.columns = self.obj.columns
1683+
else:
1684+
columns = Index(key.label for key in output)
1685+
columns.name = self.obj.columns.name
1686+
result.columns = columns
1687+
16831688
result.index = self.obj.index
16841689

16851690
return result

pandas/core/groupby/groupby.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2365,7 +2365,7 @@ def cumcount(self, ascending: bool = True):
23652365
dtype: int64
23662366
"""
23672367
with group_selection_context(self):
2368-
index = self._selected_obj.index
2368+
index = self._selected_obj._get_axis(self.axis)
23692369
cumcounts = self._cumcount_array(ascending=ascending)
23702370
return self._obj_1d_constructor(cumcounts, index)
23712371

@@ -2706,8 +2706,8 @@ def pct_change(self, periods=1, fill_method="pad", limit=None, freq=None, axis=0
27062706
fill_method = "pad"
27072707
limit = 0
27082708
filled = getattr(self, fill_method)(limit=limit)
2709-
fill_grp = filled.groupby(self.grouper.codes)
2710-
shifted = fill_grp.shift(periods=periods, freq=freq)
2709+
fill_grp = filled.groupby(self.grouper.codes, axis=self.axis)
2710+
shifted = fill_grp.shift(periods=periods, freq=freq, axis=self.axis)
27112711
return (filled / shifted) - 1
27122712

27132713
@Substitution(name="groupby")

pandas/tests/frame/apply/test_frame_transform.py

-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ def test_transform_groupby_kernel(axis, float_frame, op):
2727
pytest.xfail("DataFrame.cumcount does not exist")
2828
if op == "tshift":
2929
pytest.xfail("Only works on time index and is deprecated")
30-
if axis == 1 or axis == "columns":
31-
pytest.xfail("GH 36308: groupby.transform with axis=1 is broken")
3230

3331
args = [0.0] if op == "fillna" else []
3432
if axis == 0 or axis == "index":

pandas/tests/groupby/transform/test_transform.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,25 @@ def test_transform_broadcast(tsframe, ts):
158158
assert_fp_equal(res.xs(idx), agged[idx])
159159

160160

161-
def test_transform_axis(tsframe):
161+
def test_transform_axis_1(transformation_func):
162+
# GH 36308
163+
if transformation_func == "tshift":
164+
pytest.xfail("tshift is deprecated")
165+
args = ("ffill",) if transformation_func == "fillna" else tuple()
166+
167+
df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"])
168+
result = df.groupby([0, 0, 1], axis=1).transform(transformation_func, *args)
169+
expected = df.T.groupby([0, 0, 1]).transform(transformation_func, *args).T
170+
171+
if transformation_func == "diff":
172+
# Result contains nans, so transpose coerces to float
173+
expected["b"] = expected["b"].astype("int64")
174+
175+
# cumcount returns Series; the rest are DataFrame
176+
tm.assert_equal(result, expected)
177+
178+
179+
def test_transform_axis_ts(tsframe):
162180

163181
# make sure that we are setting the axes
164182
# correctly when on axis=0 or 1

0 commit comments

Comments
 (0)