Reviewer note (free text ahead of the first "diff --git" header; patch
tools skip it): in both files, _iterate_slices now yields each Series by
itself instead of a (name, Series) pair, and the loops that consume it
take the label from `obj.name`. get_converter's isinstance check is
narrowed to `datetime.datetime`, and the typing imports are trimmed.

diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py
index 002d8640f109d..06dc1e2c4fa51 100644
--- a/pandas/core/groupby/generic.py
+++ b/pandas/core/groupby/generic.py
@@ -10,19 +10,7 @@
 from functools import partial
 from textwrap import dedent
 import typing
-from typing import (
-    Any,
-    Callable,
-    FrozenSet,
-    Hashable,
-    Iterable,
-    Optional,
-    Sequence,
-    Tuple,
-    Type,
-    Union,
-    cast,
-)
+from typing import Any, Callable, FrozenSet, Iterable, Sequence, Type, Union, cast
 import warnings
 
 import numpy as np
@@ -142,8 +130,8 @@ def pinner(cls):
 class SeriesGroupBy(GroupBy):
     _apply_whitelist = base.series_apply_whitelist
 
-    def _iterate_slices(self) -> Iterable[Tuple[Optional[Hashable], Series]]:
-        yield self._selection_name, self._selected_obj
+    def _iterate_slices(self) -> Iterable[Series]:
+        yield self._selected_obj
 
     @property
     def _selection_name(self):
@@ -923,20 +911,20 @@ def aggregate(self, func=None, *args, **kwargs):
 
     agg = aggregate
 
-    def _iterate_slices(self) -> Iterable[Tuple[Optional[Hashable], Series]]:
+    def _iterate_slices(self) -> Iterable[Series]:
         obj = self._selected_obj
         if self.axis == 1:
             obj = obj.T
 
         if isinstance(obj, Series) and obj.name not in self.exclusions:
             # Occurs when doing DataFrameGroupBy(...)["X"]
-            yield obj.name, obj
+            yield obj
         else:
             for label, values in obj.items():
                 if label in self.exclusions:
                     continue
 
-                yield label, values
+                yield values
 
     def _cython_agg_general(
         self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index 280f1e88b0ea8..cc538b291ed9a 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -14,7 +14,7 @@ class providing the base-class of operations.
 import inspect
 import re
 import types
-from typing import FrozenSet, Hashable, Iterable, List, Optional, Tuple, Type, Union
+from typing import FrozenSet, Iterable, List, Optional, Tuple, Type, Union
 
 import numpy as np
 
@@ -439,7 +439,7 @@ def _get_indices(self, names):
         def get_converter(s):
             # possibly convert to the actual key types
             # in the indices, could be a Timestamp or a np.datetime64
-            if isinstance(s, (Timestamp, datetime.datetime)):
+            if isinstance(s, datetime.datetime):
                 return lambda key: Timestamp(key)
             elif isinstance(s, np.datetime64):
                 return lambda key: Timestamp(key).asm8
@@ -488,6 +488,7 @@ def _get_index(self, name):
 
     @cache_readonly
     def _selected_obj(self):
+        # Note: _selected_obj is always just `self.obj` for SeriesGroupBy
 
         if self._selection is None or isinstance(self.obj, Series):
             if self._group_selection is not None:
@@ -736,7 +737,7 @@ def _python_apply_general(self, f):
             keys, values, not_indexed_same=mutated or self.mutated
         )
 
-    def _iterate_slices(self) -> Iterable[Tuple[Optional[Hashable], Series]]:
+    def _iterate_slices(self) -> Iterable[Series]:
         raise AbstractMethodError(self)
 
     def transform(self, func, *args, **kwargs):
@@ -832,7 +833,8 @@ def _transform_should_cast(self, func_nm: str) -> bool:
 
     def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs):
         output = collections.OrderedDict()  # type: dict
-        for name, obj in self._iterate_slices():
+        for obj in self._iterate_slices():
+            name = obj.name
             is_numeric = is_numeric_dtype(obj.dtype)
             if numeric_only and not is_numeric:
                 continue
@@ -864,7 +866,8 @@ def _cython_agg_general(
         self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
     ):
         output = {}
-        for name, obj in self._iterate_slices():
+        for obj in self._iterate_slices():
+            name = obj.name
             is_numeric = is_numeric_dtype(obj.dtype)
             if numeric_only and not is_numeric:
                 continue
@@ -883,7 +886,8 @@ def _python_agg_general(self, func, *args, **kwargs):
 
         # iterate through "columns" ex exclusions to populate output dict
         output = {}
-        for name, obj in self._iterate_slices():
+        for obj in self._iterate_slices():
+            name = obj.name
             if self.grouper.ngroups == 0:
                 # agg_series below assumes ngroups > 0
                 continue
@@ -2242,7 +2246,8 @@ def _get_cythonized_result(
         output = collections.OrderedDict()  # type: dict
         base_func = getattr(libgroupby, how)
 
-        for name, obj in self._iterate_slices():
+        for obj in self._iterate_slices():
+            name = obj.name
             values = obj._data._values
 
             if aggregate: