
Commit 5d11658

REF: groupby Series selection with as_index=False (#50744)
* REF: groupby Series selection with as_index=False
* GH#
* type-hinting fixes
1 parent e9b3562 commit 5d11658
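
The refactor routes groupby output construction for as_index=False through a shared helper instead of ad-hoc handling in each method. A minimal sketch of the user-facing behavior being standardized; the frame and the "key"/"val" column names below are illustrative, not taken from the commit:

import pandas as pd

df = pd.DataFrame({"key": ["a", "a", "b"], "val": [1, 2, 3]})

# With as_index=False the grouping column comes back as a regular column
# and the result carries a default RangeIndex rather than a "key" index.
out = df.groupby("key", as_index=False)["val"].sum()
print(out)
#   key  val
# 0   a    3
# 1   b    3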

File tree

7 files changed: +161 additions, -92 deletions (three of the changed files are shown below)


pandas/core/apply.py (+55 -26)

@@ -2,12 +2,14 @@
 
 import abc
 from collections import defaultdict
+from contextlib import nullcontext
 from functools import partial
 import inspect
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
+    ContextManager,
     DefaultDict,
     Dict,
     Hashable,
@@ -292,6 +294,10 @@ def agg_list_like(self) -> DataFrame | Series:
         -------
         Result of aggregation.
         """
+        from pandas.core.groupby.generic import (
+            DataFrameGroupBy,
+            SeriesGroupBy,
+        )
         from pandas.core.reshape.concat import concat
 
         obj = self.obj
@@ -312,26 +318,36 @@ def agg_list_like(self) -> DataFrame | Series:
         results = []
         keys = []
 
-        # degenerate case
-        if selected_obj.ndim == 1:
-            for a in arg:
-                colg = obj._gotitem(selected_obj.name, ndim=1, subset=selected_obj)
-                new_res = colg.aggregate(a)
-                results.append(new_res)
+        is_groupby = isinstance(obj, (DataFrameGroupBy, SeriesGroupBy))
+        context_manager: ContextManager
+        if is_groupby:
+            # When as_index=False, we combine all results using indices
+            # and adjust index after
+            context_manager = com.temp_setattr(obj, "as_index", True)
+        else:
+            context_manager = nullcontext()
+        with context_manager:
+            # degenerate case
+            if selected_obj.ndim == 1:
 
-                # make sure we find a good name
-                name = com.get_callable_name(a) or a
-                keys.append(name)
+                for a in arg:
+                    colg = obj._gotitem(selected_obj.name, ndim=1, subset=selected_obj)
+                    new_res = colg.aggregate(a)
+                    results.append(new_res)
 
-        # multiples
-        else:
-            indices = []
-            for index, col in enumerate(selected_obj):
-                colg = obj._gotitem(col, ndim=1, subset=selected_obj.iloc[:, index])
-                new_res = colg.aggregate(arg)
-                results.append(new_res)
-                indices.append(index)
-            keys = selected_obj.columns.take(indices)
+                    # make sure we find a good name
+                    name = com.get_callable_name(a) or a
+                    keys.append(name)
+
+            # multiples
+            else:
+                indices = []
+                for index, col in enumerate(selected_obj):
+                    colg = obj._gotitem(col, ndim=1, subset=selected_obj.iloc[:, index])
+                    new_res = colg.aggregate(arg)
+                    results.append(new_res)
+                    indices.append(index)
+                keys = selected_obj.columns.take(indices)
 
         try:
             concatenated = concat(results, keys=keys, axis=1, sort=False)
@@ -366,6 +382,10 @@ def agg_dict_like(self) -> DataFrame | Series:
         Result of aggregation.
         """
         from pandas import Index
+        from pandas.core.groupby.generic import (
+            DataFrameGroupBy,
+            SeriesGroupBy,
+        )
         from pandas.core.reshape.concat import concat
 
         obj = self.obj
@@ -384,15 +404,24 @@ def agg_dict_like(self) -> DataFrame | Series:
 
         arg = self.normalize_dictlike_arg("agg", selected_obj, arg)
 
-        if selected_obj.ndim == 1:
-            # key only used for output
-            colg = obj._gotitem(selection, ndim=1)
-            results = {key: colg.agg(how) for key, how in arg.items()}
+        is_groupby = isinstance(obj, (DataFrameGroupBy, SeriesGroupBy))
+        context_manager: ContextManager
+        if is_groupby:
+            # When as_index=False, we combine all results using indices
+            # and adjust index after
+            context_manager = com.temp_setattr(obj, "as_index", True)
         else:
-            # key used for column selection and output
-            results = {
-                key: obj._gotitem(key, ndim=1).agg(how) for key, how in arg.items()
-            }
+            context_manager = nullcontext()
+        with context_manager:
+            if selected_obj.ndim == 1:
+                # key only used for output
+                colg = obj._gotitem(selection, ndim=1)
+                results = {key: colg.agg(how) for key, how in arg.items()}
+            else:
+                # key used for column selection and output
+                results = {
+                    key: obj._gotitem(key, ndim=1).agg(how) for key, how in arg.items()
+                }
 
         # set the final keys
         keys = list(arg.keys())
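
Both agg_list_like and agg_dict_like now run their per-column aggregation under a context manager that temporarily forces as_index=True on groupby objects (so intermediate results line up on the group index) and is a no-op for everything else. A rough, standalone sketch of that pattern, not the pandas implementation itself (the real helper is pandas.core.common.temp_setattr; this version only mirrors the idea):

from contextlib import contextmanager, nullcontext

@contextmanager
def temp_setattr(obj, attr: str, value):
    # Temporarily override an attribute and always restore it on exit.
    old = getattr(obj, attr)
    setattr(obj, attr, value)
    try:
        yield obj
    finally:
        setattr(obj, attr, old)

def as_index_context(obj, is_groupby: bool):
    # Mirrors the selection logic above: force as_index=True only for
    # groupby objects, otherwise do nothing.
    return temp_setattr(obj, "as_index", True) if is_groupby else nullcontext()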

pandas/core/base.py (+5 -8)

@@ -216,6 +216,9 @@ def _obj_with_exclusions(self):
         if self._selection is not None and isinstance(self.obj, ABCDataFrame):
             return self.obj[self._selection_list]
 
+        if isinstance(self.obj, ABCSeries):
+            return self.obj
+
         if len(self.exclusions) > 0:
             # equivalent to `self.obj.drop(self.exclusions, axis=1)
             # but this avoids consolidating and making a copy
@@ -235,17 +238,11 @@ def __getitem__(self, key):
                 raise KeyError(f"Columns not found: {str(bad_keys)[1:-1]}")
             return self._gotitem(list(key), ndim=2)
 
-        elif not getattr(self, "as_index", False):
-            if key not in self.obj.columns:
-                raise KeyError(f"Column not found: {key}")
-            return self._gotitem(key, ndim=2)
-
         else:
             if key not in self.obj:
                 raise KeyError(f"Column not found: {key}")
-            subset = self.obj[key]
-            ndim = subset.ndim
-            return self._gotitem(key, ndim=ndim, subset=subset)
+            ndim = self.obj[key].ndim
+            return self._gotitem(key, ndim=ndim)
 
     def _gotitem(self, key, ndim: int, subset=None):
         """

pandas/core/groupby/generic.py (+52 -38)

@@ -248,7 +248,11 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
                 data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
             )
             index = self.grouper.result_index
-            return self.obj._constructor(result.ravel(), index=index, name=data.name)
+            result = self.obj._constructor(result.ravel(), index=index, name=data.name)
+            if not self.as_index:
+                result = self._insert_inaxis_grouper(result)
+                result.index = default_index(len(result))
+            return result
 
         relabeling = func is None
         columns = None
@@ -268,6 +272,9 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
                 # columns is not narrowed by mypy from relabeling flag
                 assert columns is not None  # for mypy
                 ret.columns = columns
+            if not self.as_index:
+                ret = self._insert_inaxis_grouper(ret)
+                ret.index = default_index(len(ret))
             return ret
 
         else:
@@ -287,23 +294,24 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
 
                 # result is a dict whose keys are the elements of result_index
                 index = self.grouper.result_index
-                return Series(result, index=index)
+                result = Series(result, index=index)
+                if not self.as_index:
+                    result = self._insert_inaxis_grouper(result)
+                    result.index = default_index(len(result))
+                return result
 
     agg = aggregate
 
     def _aggregate_multiple_funcs(self, arg) -> DataFrame:
         if isinstance(arg, dict):
-
-            # show the deprecation, but only if we
-            # have not shown a higher level one
-            # GH 15931
-            raise SpecificationError("nested renamer is not supported")
-
-        if any(isinstance(x, (tuple, list)) for x in arg):
+            if self.as_index:
+                # GH 15931
+                raise SpecificationError("nested renamer is not supported")
+            else:
+                # GH#50684 - This accidentally worked in 1.x
+                arg = list(arg.items())
+        elif any(isinstance(x, (tuple, list)) for x in arg):
             arg = [(x, x) if not isinstance(x, (tuple, list)) else x for x in arg]
-
-            # indicated column order
-            columns = next(zip(*arg))
         else:
             # list of functions / function names
             columns = []
@@ -313,10 +321,13 @@ def _aggregate_multiple_funcs(self, arg) -> DataFrame:
             arg = zip(columns, arg)
 
         results: dict[base.OutputKey, DataFrame | Series] = {}
-        for idx, (name, func) in enumerate(arg):
+        with com.temp_setattr(self, "as_index", True):
+            # Combine results using the index, need to adjust index after
+            # if as_index=False (GH#50724)
+            for idx, (name, func) in enumerate(arg):
 
-            key = base.OutputKey(label=name, position=idx)
-            results[key] = self.aggregate(func)
+                key = base.OutputKey(label=name, position=idx)
+                results[key] = self.aggregate(func)
 
         if any(isinstance(x, DataFrame) for x in results.values()):
             from pandas import concat
@@ -396,12 +407,18 @@ def _wrap_applied_output(
             )
             if isinstance(result, Series):
                 result.name = self.obj.name
+            if not self.as_index and not_indexed_same:
+                result = self._insert_inaxis_grouper(result)
+                result.index = default_index(len(result))
             return result
         else:
             # GH #6265 #24880
             result = self.obj._constructor(
                 data=values, index=self.grouper.result_index, name=self.obj.name
             )
+            if not self.as_index:
+                result = self._insert_inaxis_grouper(result)
+                result.index = default_index(len(result))
             return self._reindex_output(result)
 
     def _aggregate_named(self, func, *args, **kwargs):
@@ -577,7 +594,7 @@ def true_and_notna(x) -> bool:
         filtered = self._apply_filter(indices, dropna)
         return filtered
 
-    def nunique(self, dropna: bool = True) -> Series:
+    def nunique(self, dropna: bool = True) -> Series | DataFrame:
         """
         Return number of unique elements in the group.
 
@@ -629,7 +646,12 @@ def nunique(self, dropna: bool = True) -> Series:
                 # GH#21334s
                 res[ids[idx]] = out
 
-        result = self.obj._constructor(res, index=ri, name=self.obj.name)
+        result: Series | DataFrame = self.obj._constructor(
+            res, index=ri, name=self.obj.name
+        )
+        if not self.as_index:
+            result = self._insert_inaxis_grouper(result)
+            result.index = default_index(len(result))
         return self._reindex_output(result, fill_value=0)
 
     @doc(Series.describe)
@@ -643,12 +665,11 @@ def value_counts(
         ascending: bool = False,
         bins=None,
         dropna: bool = True,
-    ) -> Series:
+    ) -> Series | DataFrame:
         if bins is None:
             result = self._value_counts(
                 normalize=normalize, sort=sort, ascending=ascending, dropna=dropna
             )
-            assert isinstance(result, Series)
             return result
 
         from pandas.core.reshape.merge import get_join_indexers
@@ -786,7 +807,11 @@ def build_codes(lev_codes: np.ndarray) -> np.ndarray:
 
         if is_integer_dtype(out.dtype):
             out = ensure_int64(out)
-        return self.obj._constructor(out, index=mi, name=self.obj.name)
+        result = self.obj._constructor(out, index=mi, name=self.obj.name)
+        if not self.as_index:
+            result.name = "proportion" if normalize else "count"
+            result = result.reset_index()
+        return result
 
     def fillna(
         self,
@@ -1274,7 +1299,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
             result.columns = result.columns.droplevel(-1)
 
         if not self.as_index:
-            self._insert_inaxis_grouper_inplace(result)
+            result = self._insert_inaxis_grouper(result)
             result.index = default_index(len(result))
 
         return result
@@ -1386,7 +1411,7 @@ def _wrap_applied_output(
                 return self.obj._constructor_sliced(values, index=key_index)
             else:
                 result = self.obj._constructor(values, columns=[self._selection])
-                self._insert_inaxis_grouper_inplace(result)
+                result = self._insert_inaxis_grouper(result)
                 return result
         else:
             # values are Series
@@ -1443,7 +1468,7 @@ def _wrap_applied_output_series(
         result = self.obj._constructor(stacked_values, index=index, columns=columns)
 
         if not self.as_index:
-            self._insert_inaxis_grouper_inplace(result)
+            result = self._insert_inaxis_grouper(result)
 
         return self._reindex_output(result)
 
@@ -1774,7 +1799,9 @@ def _gotitem(self, key, ndim: int, subset=None):
                 subset,
                 level=self.level,
                 grouper=self.grouper,
+                exclusions=self.exclusions,
                 selection=key,
+                as_index=self.as_index,
                 sort=self.sort,
                 group_keys=self.group_keys,
                 observed=self.observed,
@@ -1790,19 +1817,6 @@ def _get_data_to_aggregate(self) -> Manager2D:
         else:
             return obj._mgr
 
-    def _insert_inaxis_grouper_inplace(self, result: DataFrame) -> None:
-        # zip in reverse so we can always insert at loc 0
-        columns = result.columns
-        for name, lev, in_axis in zip(
-            reversed(self.grouper.names),
-            reversed(self.grouper.get_group_levels()),
-            reversed([grp.in_axis for grp in self.grouper.groupings]),
-        ):
-            # GH #28549
-            # When using .apply(-), name will be in columns already
-            if in_axis and name not in columns:
-                result.insert(0, name, lev)
-
     def _indexed_output_to_ndframe(
         self, output: Mapping[base.OutputKey, ArrayLike]
     ) -> DataFrame:
@@ -1825,7 +1839,7 @@ def _wrap_agged_manager(self, mgr: Manager2D) -> DataFrame:
             mgr.set_axis(1, index)
             result = self.obj._constructor(mgr)
 
-            self._insert_inaxis_grouper_inplace(result)
+            result = self._insert_inaxis_grouper(result)
             result = result._consolidate()
         else:
            index = self.grouper.result_index
@@ -1918,7 +1932,7 @@ def nunique(self, dropna: bool = True) -> DataFrame:
 
         if not self.as_index:
             results.index = default_index(len(results))
-            self._insert_inaxis_grouper_inplace(results)
+            results = self._insert_inaxis_grouper(results)
 
         return results
 
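The call sites above all go through _insert_inaxis_grouper, which replaces the removed in-place helper and returns a new object instead of mutating its argument, so Series results from SeriesGroupBy can be handled as well. A hedged sketch of such a non-mutating variant, adapted from the removed _insert_inaxis_grouper_inplace shown in this diff; the real method lives on the shared groupby base class, and the Series-to-frame promotion and the defensive copy below are assumptions on my part:

from pandas import DataFrame, Series

def insert_inaxis_grouper(result, names, levels, in_axis_flags) -> DataFrame:
    # names, levels and in_axis_flags stand in for self.grouper.names,
    # self.grouper.get_group_levels() and the groupings' in_axis flags.
    if isinstance(result, Series):
        # Promote Series results so grouping columns can sit alongside them.
        result = result.to_frame()
    else:
        result = result.copy()
    # zip in reverse so we can always insert at loc 0, as in the removed helper
    for name, lev, in_axis in zip(
        reversed(list(names)), reversed(list(levels)), reversed(list(in_axis_flags))
    ):
        # GH #28549: when using .apply(-), name may already be a column
        if in_axis and name not in result.columns:
            result.insert(0, name, lev)
    return result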