Skip to content

Commit 8b4c4b4

Browse files
authored
REF: de-duplicate result index construction in groupby (#43466)
1 parent abea8fa commit 8b4c4b4

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

pandas/core/groupby/generic.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -431,16 +431,9 @@ def _wrap_applied_output(
431431
)
432432
assert values is not None
433433

434-
def _get_index() -> Index:
435-
if self.grouper.nkeys > 1:
436-
index = MultiIndex.from_tuples(keys, names=self.grouper.names)
437-
else:
438-
index = Index._with_infer(keys, name=self.grouper.names[0])
439-
return index
440-
441434
if isinstance(values[0], dict):
442435
# GH #823 #24880
443-
index = _get_index()
436+
index = self._group_keys_index
444437
res_df = self.obj._constructor_expanddim(values, index=index)
445438
res_df = self._reindex_output(res_df)
446439
# if self.observed is False,
@@ -453,7 +446,7 @@ def _get_index() -> Index:
453446
else:
454447
# GH #6265 #24880
455448
result = self.obj._constructor(
456-
data=values, index=_get_index(), name=self.obj.name
449+
data=values, index=self._group_keys_index, name=self.obj.name
457450
)
458451
return self._reindex_output(result)
459452

pandas/core/groupby/groupby.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -1177,6 +1177,18 @@ def _resolve_numeric_only(self, numeric_only: bool | lib.NoDefault) -> bool:
11771177
# expected "bool")
11781178
return numeric_only # type: ignore[return-value]
11791179

1180+
@cache_readonly
1181+
def _group_keys_index(self) -> Index:
1182+
# The index to use for the result of Groupby Aggregations.
1183+
# This _may_ be redundant with self.grouper.result_index, but that
1184+
# has not been conclusively proven yet.
1185+
keys = self.grouper._get_group_keys()
1186+
if self.grouper.nkeys > 1:
1187+
index = MultiIndex.from_tuples(keys, names=self.grouper.names)
1188+
else:
1189+
index = Index._with_infer(keys, name=self.grouper.names[0])
1190+
return index
1191+
11801192
# -----------------------------------------------------------------
11811193
# numba
11821194

@@ -1244,15 +1256,15 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
12441256
data and indices into a Numba jitted function.
12451257
"""
12461258
starts, ends, sorted_index, sorted_data = self._numba_prep(func, data)
1247-
group_keys = self.grouper._get_group_keys()
1259+
index = self._group_keys_index
12481260

12491261
numba_agg_func = numba_.generate_numba_agg_func(kwargs, func, engine_kwargs)
12501262
result = numba_agg_func(
12511263
sorted_data,
12521264
sorted_index,
12531265
starts,
12541266
ends,
1255-
len(group_keys),
1267+
len(index),
12561268
len(data.columns),
12571269
*args,
12581270
)
@@ -1261,10 +1273,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
12611273
if cache_key not in NUMBA_FUNC_CACHE:
12621274
NUMBA_FUNC_CACHE[cache_key] = numba_agg_func
12631275

1264-
if self.grouper.nkeys > 1:
1265-
index = MultiIndex.from_tuples(group_keys, names=self.grouper.names)
1266-
else:
1267-
index = Index(group_keys, name=self.grouper.names[0])
12681276
return result, index
12691277

12701278
# -----------------------------------------------------------------

0 commit comments

Comments
 (0)