Skip to content

Commit ab34dd6

Browse files
authored
REGR: get_group raising with axis=1 (#54882)
1 parent c7325d7 commit ab34dd6

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

doc/source/whatsnew/v2.1.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ including other versions of pandas.
1414
Fixed regressions
1515
~~~~~~~~~~~~~~~~~
1616
- Fixed regression in :func:`read_csv` when ``usecols`` is given and ``dtypes`` is a dict for ``engine="python"`` (:issue:`54868`)
17+
- Fixed regression in :meth:`.GroupBy.get_group` raising for ``axis=1`` (:issue:`54858`)
1718
- Fixed regression in :meth:`DataFrame.__setitem__` raising ``AssertionError`` when setting a :class:`Series` with a partial :class:`MultiIndex` (:issue:`54875`)
1819
- Fixed regression when comparing a :class:`Series` with ``datetime64`` dtype with ``None`` (:issue:`54870`)
1920

pandas/core/groupby/groupby.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1080,7 +1080,8 @@ def get_group(self, name, obj=None) -> DataFrame | Series:
10801080
raise KeyError(name)
10811081

10821082
if obj is None:
1083-
return self._selected_obj.iloc[inds]
1083+
indexer = inds if self.axis == 0 else (slice(None), inds)
1084+
return self._selected_obj.iloc[indexer]
10841085
else:
10851086
warnings.warn(
10861087
"obj is deprecated and will be removed in a future version. "

pandas/tests/groupby/test_groupby.py

+23
Original file line numberDiff line numberDiff line change
@@ -3187,3 +3187,26 @@ def test_depr_get_group_len_1_list_likes(test_series, kwarg, value, name, warn):
31873187
else:
31883188
expected = DataFrame({"b": [3, 4]}, index=Index([1, 1], name="a"))
31893189
tm.assert_equal(result, expected)
3190+
3191+
3192+
def test_get_group_axis_1():
3193+
# GH#54858
3194+
df = DataFrame(
3195+
{
3196+
"col1": [0, 3, 2, 3],
3197+
"col2": [4, 1, 6, 7],
3198+
"col3": [3, 8, 2, 10],
3199+
"col4": [1, 13, 6, 15],
3200+
"col5": [-4, 5, 6, -7],
3201+
}
3202+
)
3203+
with tm.assert_produces_warning(FutureWarning, match="deprecated"):
3204+
grouped = df.groupby(axis=1, by=[1, 2, 3, 2, 1])
3205+
result = grouped.get_group(1)
3206+
expected = DataFrame(
3207+
{
3208+
"col1": [0, 3, 2, 3],
3209+
"col5": [-4, 5, 6, -7],
3210+
}
3211+
)
3212+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)