diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index b16ca0a80c5b4..010c8a22dcc9f 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -132,9 +132,8 @@ Plotting Groupby/resample/rolling ^^^^^^^^^^^^^^^^^^^^^^^^ +- Bug in :meth:`DataFrame.groupby` does not always maintain column index name for ``any``, ``all``, ``bfill``, ``ffill``, ``shift`` (:issue:`29764`) - -- - Reshaping ^^^^^^^^^ diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index c50b753cf3293..bc46cc041cbf3 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1728,9 +1728,10 @@ def _wrap_aggregated_output( ------- DataFrame """ + idx_name = output.pop("idx_name", None) indexed_output = {key.position: val for key, val in output.items()} name = self._obj_with_exclusions._get_axis(1 - self.axis).name - columns = Index([key.label for key in output], name=name) + columns = Index([key.label for key in output], name=idx_name) result = self.obj._constructor(indexed_output) result.columns = columns @@ -1762,8 +1763,9 @@ def _wrap_transformed_output( ------- DataFrame """ + idx_name = output.pop("idx_name", None) indexed_output = {key.position: val for key, val in output.items()} - columns = Index(key.label for key in output) + columns = Index([key.label for key in output], name=idx_name) result = self.obj._constructor(indexed_output) result.columns = columns diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index ac45222625569..5a20af9607cd9 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2450,8 +2450,11 @@ def _get_cythonized_result( grouper = self.grouper labels, _, ngroups = grouper.group_info - output: Dict[base.OutputKey, np.ndarray] = {} + output: Dict[Union[base.OutputKey, str], Union[np.ndarray, str]] = {} base_func = getattr(libgroupby, how) + obj = self._selected_obj + if isinstance(obj, DataFrame): + output["idx_name"] = obj.columns.name error_msg = "" for idx, obj in enumerate(self._iterate_slices()): diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index 8c51ebf89f5c0..19dd8a1cee355 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -2069,3 +2069,26 @@ def test_group_on_two_row_multiindex_returns_one_tuple_key(): assert len(result) == 1 key = (1, 2) assert (result[key] == expected[key]).all() + + +@pytest.mark.parametrize("func", ["sum", "any", "shift"]) +def test_groupby_column_index_name_lost(func): + # GH: 29764 groupby loses index sometimes + expected = pd.Index(["a"], name="idx") + df = pd.DataFrame([[1]], columns=expected) + df_grouped = df.groupby([1]) + result = getattr(df_grouped, func)().columns + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("func", ["ffill", "bfill"]) +def test_groupby_column_index_name_lost_fill_funcs(func): + # GH: 29764 groupby loses index sometimes + df = pd.DataFrame( + [[1, 1.0, -1.0], [1, np.nan, np.nan], [1, 2.0, -2.0]], + columns=pd.Index(["type", "a", "b"], name="idx"), + ) + df_grouped = df.groupby(["type"])[["a", "b"]] + result = getattr(df_grouped, func)().columns + expected = pd.Index(["a", "b"], name="idx") + tm.assert_index_equal(result, expected)