
Commit f670a67

jbrockmendel authored and JulianWgs committed

BUG: columns name retention in groupby methods (pandas-dev#41497)

1 parent 13b2129 · commit f670a67


6 files changed, +58 -23 lines changed

doc/source/whatsnew/v1.3.0.rst (+2)

@@ -1143,6 +1143,8 @@ Groupby/resample/rolling
 - Bug in :class:`DataFrameGroupBy` aggregations incorrectly failing to drop columns with invalid dtypes for that aggregation when there are no valid columns (:issue:`41291`)
 - Bug in :meth:`DataFrame.rolling.__iter__` where ``on`` was not assigned to the index of the resulting objects (:issue:`40373`)
 - Bug in :meth:`.DataFrameGroupBy.transform` and :meth:`.DataFrameGroupBy.agg` with ``engine="numba"`` where ``*args`` were being cached with the user passed function (:issue:`41647`)
+- Bug in :class:`DataFrameGroupBy` methods ``agg``, ``transform``, ``sum``, ``bfill``, ``ffill``, ``pad``, ``pct_change``, ``shift``, ``ohlc`` dropping ``.columns.names`` (:issue:`41497`)
+
 
 Reshaping
 ^^^^^^^^^
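
A minimal sketch of the behaviour this entry describes, assuming a pandas build that includes this fix; the frame and the column-index name "alpha" are invented for illustration:

import pandas as pd

# The columns carry a name ("alpha"); before this fix, methods such as
# sum() dropped it from the result's columns.
df = pd.DataFrame(
    [[1, 3], [1, 4], [2, 5]],
    columns=pd.Index(["A", "B"], name="alpha"),
)
result = df.groupby("A").sum()
assert result.columns.name == "alpha"  # retained with this fix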

pandas/core/apply.py (+16 -2)

@@ -348,6 +348,7 @@ def agg_list_like(self) -> FrameOrSeriesUnion:
 
         # multiples
         else:
+            indices = []
             for index, col in enumerate(selected_obj):
                 colg = obj._gotitem(col, ndim=1, subset=selected_obj.iloc[:, index])
                 try:
@@ -369,7 +370,9 @@ def agg_list_like(self) -> FrameOrSeriesUnion:
                         raise
                 else:
                     results.append(new_res)
-                    keys.append(col)
+                    indices.append(index)
+
+            keys = selected_obj.columns.take(indices)
 
         # if we are empty
         if not len(results):
@@ -407,6 +410,7 @@ def agg_dict_like(self) -> FrameOrSeriesUnion:
         -------
         Result of aggregation.
         """
+        from pandas import Index
         from pandas.core.reshape.concat import concat
 
         obj = self.obj
@@ -443,8 +447,18 @@ def agg_dict_like(self) -> FrameOrSeriesUnion:
             keys_to_use = [k for k in keys if not results[k].empty]
             # Have to check, if at least one DataFrame is not empty.
             keys_to_use = keys_to_use if keys_to_use != [] else keys
+            if selected_obj.ndim == 2:
+                # keys are columns, so we can preserve names
+                ktu = Index(keys_to_use)
+                ktu._set_names(selected_obj.columns.names)
+                # Incompatible types in assignment (expression has type "Index",
+                # variable has type "List[Hashable]")
+                keys_to_use = ktu  # type: ignore[assignment]
+
             axis = 0 if isinstance(obj, ABCSeries) else 1
-            result = concat({k: results[k] for k in keys_to_use}, axis=axis)
+            result = concat(
+                {k: results[k] for k in keys_to_use}, axis=axis, keys=keys_to_use
+            )
         elif any(is_ndframe):
             # There is a mix of NDFrames and scalars
             raise ValueError(
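
The agg_list_like change above collects positions and calls Index.take instead of appending raw labels to a list. A standalone sketch (plain pandas Index behaviour, not code from this commit) of why that preserves the columns' names:

import pandas as pd

cols = pd.Index(["A", "B", "C"], name="alpha")

# Taking by position returns an Index that still carries the name ...
taken = cols.take([0, 2])
assert list(taken) == ["A", "C"]
assert taken.name == "alpha"

# ... whereas rebuilding a fresh Index from the raw labels loses it.
rebuilt = pd.Index([cols[0], cols[2]])
assert rebuilt.name is None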

pandas/core/groupby/generic.py (+7 -6)

@@ -1020,13 +1020,15 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
 
             if isinstance(sobj, Series):
                 # GH#35246 test_groupby_as_index_select_column_sum_empty_df
-                result.columns = [sobj.name]
+                result.columns = self._obj_with_exclusions.columns.copy()
             else:
+                # Retain our column names
+                result.columns._set_names(
+                    sobj.columns.names, level=list(range(sobj.columns.nlevels))
+                )
                 # select everything except for the last level, which is the one
                 # containing the name of the function(s), see GH#32040
-                result.columns = result.columns.rename(
-                    [sobj.columns.name] * result.columns.nlevels
-                ).droplevel(-1)
+                result.columns = result.columns.droplevel(-1)
 
         if not self.as_index:
             self._insert_inaxis_grouper_inplace(result)
@@ -1665,7 +1667,7 @@ def _wrap_transformed_output(
             result.columns = self.obj.columns
         else:
             columns = Index(key.label for key in output)
-            columns.name = self.obj.columns.name
+            columns._set_names(self.obj._get_axis(1 - self.axis).names)
             result.columns = columns
 
         result.index = self.obj.index
@@ -1800,7 +1802,6 @@ def nunique(self, dropna: bool = True) -> DataFrame:
         results = self._apply_to_column_groupbys(
             lambda sgb: sgb.nunique(dropna), obj=obj
         )
-        results.columns.names = obj.columns.names  # TODO: do at higher level?
 
         if not self.as_index:
             results.index = Index(range(len(results)))
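
The aggregate() change above first copies the names onto every level of the result's columns and only then drops the trailing level that holds the function name. A standalone sketch of the Index behaviour this relies on (the level names are invented):

import pandas as pd

# Columns shaped like an aggregation result: the last level holds the
# function name and is dropped; the named user levels survive droplevel.
cols = pd.MultiIndex.from_tuples(
    [("x", 1, "sum"), ("y", 2, "sum")],
    names=["letters", "numbers", None],
)
trimmed = cols.droplevel(-1)
assert list(trimmed.names) == ["letters", "numbers"]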

pandas/core/reshape/concat.py (+7 -2)

@@ -362,8 +362,13 @@ def __init__(
                     clean_keys.append(k)
                     clean_objs.append(v)
             objs = clean_objs
-            name = getattr(keys, "name", None)
-            keys = Index(clean_keys, name=name)
+
+            if isinstance(keys, MultiIndex):
+                # TODO: retain levels?
+                keys = type(keys).from_tuples(clean_keys, names=keys.names)
+            else:
+                name = getattr(keys, "name", None)
+                keys = Index(clean_keys, name=name)
 
         if len(objs) == 0:
             raise ValueError("All objects passed were None")

pandas/tests/groupby/aggregate/test_aggregate.py (+12 -5)

@@ -300,13 +300,13 @@ def test_agg_multiple_functions_same_name_with_ohlc_present():
     # ohlc expands dimensions, so different test to the above is required.
     df = DataFrame(
         np.random.randn(1000, 3),
-        index=pd.date_range("1/1/2012", freq="S", periods=1000),
-        columns=["A", "B", "C"],
+        index=pd.date_range("1/1/2012", freq="S", periods=1000, name="dti"),
+        columns=Index(["A", "B", "C"], name="alpha"),
     )
     result = df.resample("3T").agg(
         {"A": ["ohlc", partial(np.quantile, q=0.9999), partial(np.quantile, q=0.1111)]}
     )
-    expected_index = pd.date_range("1/1/2012", freq="3T", periods=6)
+    expected_index = pd.date_range("1/1/2012", freq="3T", periods=6, name="dti")
     expected_columns = MultiIndex.from_tuples(
         [
             ("A", "ohlc", "open"),
@@ -315,7 +315,8 @@ def test_agg_multiple_functions_same_name_with_ohlc_present():
             ("A", "ohlc", "close"),
             ("A", "quantile", "A"),
             ("A", "quantile", "A"),
-        ]
+        ],
+        names=["alpha", None, None],
     )
     non_ohlc_expected_values = np.array(
         [df.resample("3T").A.quantile(q=q).values for q in [0.9999, 0.1111]]
@@ -901,14 +902,20 @@ def test_grouby_agg_loses_results_with_as_index_false_relabel_multiindex():
 def test_multiindex_custom_func(func):
     # GH 31777
     data = [[1, 4, 2], [5, 7, 1]]
-    df = DataFrame(data, columns=MultiIndex.from_arrays([[1, 1, 2], [3, 4, 3]]))
+    df = DataFrame(
+        data,
+        columns=MultiIndex.from_arrays(
+            [[1, 1, 2], [3, 4, 3]], names=["Sisko", "Janeway"]
+        ),
+    )
     result = df.groupby(np.array([0, 1])).agg(func)
     expected_dict = {
         (1, 3): {0: 1.0, 1: 5.0},
         (1, 4): {0: 4.0, 1: 7.0},
         (2, 3): {0: 2.0, 1: 1.0},
     }
     expected = DataFrame(expected_dict)
+    expected.columns = df.columns
     tm.assert_frame_equal(result, expected)
 
 

pandas/tests/groupby/test_groupby.py (+14 -8)

@@ -637,10 +637,11 @@ def test_as_index_select_column():
 
 def test_groupby_as_index_select_column_sum_empty_df():
     # GH 35246
-    df = DataFrame(columns=["A", "B", "C"])
+    df = DataFrame(columns=Index(["A", "B", "C"], name="alpha"))
     left = df.groupby(by="A", as_index=False)["B"].sum(numeric_only=False)
-    assert type(left) is DataFrame
-    assert left.to_dict() == {"A": {}, "B": {}}
+
+    expected = DataFrame(columns=df.columns[:2], index=range(0))
+    tm.assert_frame_equal(left, expected)
 
 
 def test_groupby_as_index_agg(df):
@@ -1944,8 +1945,8 @@ def test_groupby_agg_ohlc_non_first():
     # GH 21716
     df = DataFrame(
         [[1], [1]],
-        columns=["foo"],
-        index=date_range("2018-01-01", periods=2, freq="D"),
+        columns=Index(["foo"], name="mycols"),
+        index=date_range("2018-01-01", periods=2, freq="D", name="dti"),
     )
 
     expected = DataFrame(
@@ -1957,9 +1958,10 @@ def test_groupby_agg_ohlc_non_first():
                 ("foo", "ohlc", "high"),
                 ("foo", "ohlc", "low"),
                 ("foo", "ohlc", "close"),
-            )
+            ),
+            names=["mycols", None, None],
         ),
-        index=date_range("2018-01-01", periods=2, freq="D"),
+        index=date_range("2018-01-01", periods=2, freq="D", name="dti"),
     )
 
     result = df.groupby(Grouper(freq="D")).agg(["sum", "ohlc"])
@@ -2131,7 +2133,11 @@ def test_groupby_duplicate_index():
 
 
 @pytest.mark.parametrize(
-    "idx", [Index(["a", "a"]), MultiIndex.from_tuples((("a", "a"), ("a", "a")))]
+    "idx",
+    [
+        Index(["a", "a"], name="foo"),
+        MultiIndex.from_tuples((("a", "a"), ("a", "a")), names=["foo", "bar"]),
+    ],
 )
 @pytest.mark.filterwarnings("ignore:tshift is deprecated:FutureWarning")
 def test_dup_labels_output_shape(groupby_func, idx):
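
The empty-frame test above now builds an expected frame and uses tm.assert_frame_equal instead of comparing to_dict() output; assert_frame_equal checks index and columns names by default (check_names=True), so it can catch exactly the kind of name loss this commit fixes. A small sketch of that difference (the frames are invented for illustration):

import pandas as pd
import pandas._testing as tm

a = pd.DataFrame(columns=pd.Index(["A", "B"], name="alpha"))
b = pd.DataFrame(columns=pd.Index(["A", "B"]))  # same labels, name dropped

assert a.to_dict() == b.to_dict()  # a dict comparison cannot see the lost name

try:
    tm.assert_frame_equal(a, b)  # flags the differing columns name
except AssertionError:
    pass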
