Skip to content

Commit 29be383

Browse files
jbrockmendeljreback
authored andcommitted
REF: simplify _iterate_slices (#29629)
1 parent ded50fd commit 29be383

File tree

2 files changed

+18
-25
lines changed

2 files changed

+18
-25
lines changed

pandas/core/groupby/generic.py

+6-18
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,7 @@
1010
from functools import partial
1111
from textwrap import dedent
1212
import typing
13-
from typing import (
14-
Any,
15-
Callable,
16-
FrozenSet,
17-
Hashable,
18-
Iterable,
19-
Optional,
20-
Sequence,
21-
Tuple,
22-
Type,
23-
Union,
24-
cast,
25-
)
13+
from typing import Any, Callable, FrozenSet, Iterable, Sequence, Type, Union, cast
2614
import warnings
2715

2816
import numpy as np
@@ -142,8 +130,8 @@ def pinner(cls):
142130
class SeriesGroupBy(GroupBy):
143131
_apply_whitelist = base.series_apply_whitelist
144132

145-
def _iterate_slices(self) -> Iterable[Tuple[Optional[Hashable], Series]]:
146-
yield self._selection_name, self._selected_obj
133+
def _iterate_slices(self) -> Iterable[Series]:
134+
yield self._selected_obj
147135

148136
@property
149137
def _selection_name(self):
@@ -923,20 +911,20 @@ def aggregate(self, func=None, *args, **kwargs):
923911

924912
agg = aggregate
925913

926-
def _iterate_slices(self) -> Iterable[Tuple[Optional[Hashable], Series]]:
914+
def _iterate_slices(self) -> Iterable[Series]:
927915
obj = self._selected_obj
928916
if self.axis == 1:
929917
obj = obj.T
930918

931919
if isinstance(obj, Series) and obj.name not in self.exclusions:
932920
# Occurs when doing DataFrameGroupBy(...)["X"]
933-
yield obj.name, obj
921+
yield obj
934922
else:
935923
for label, values in obj.items():
936924
if label in self.exclusions:
937925
continue
938926

939-
yield label, values
927+
yield values
940928

941929
def _cython_agg_general(
942930
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1

pandas/core/groupby/groupby.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class providing the base-class of operations.
1414
import inspect
1515
import re
1616
import types
17-
from typing import FrozenSet, Hashable, Iterable, List, Optional, Tuple, Type, Union
17+
from typing import FrozenSet, Iterable, List, Optional, Tuple, Type, Union
1818

1919
import numpy as np
2020

@@ -439,7 +439,7 @@ def _get_indices(self, names):
439439
def get_converter(s):
440440
# possibly convert to the actual key types
441441
# in the indices, could be a Timestamp or a np.datetime64
442-
if isinstance(s, (Timestamp, datetime.datetime)):
442+
if isinstance(s, datetime.datetime):
443443
return lambda key: Timestamp(key)
444444
elif isinstance(s, np.datetime64):
445445
return lambda key: Timestamp(key).asm8
@@ -488,6 +488,7 @@ def _get_index(self, name):
488488

489489
@cache_readonly
490490
def _selected_obj(self):
491+
# Note: _selected_obj is always just `self.obj` for SeriesGroupBy
491492

492493
if self._selection is None or isinstance(self.obj, Series):
493494
if self._group_selection is not None:
@@ -736,7 +737,7 @@ def _python_apply_general(self, f):
736737
keys, values, not_indexed_same=mutated or self.mutated
737738
)
738739

739-
def _iterate_slices(self) -> Iterable[Tuple[Optional[Hashable], Series]]:
740+
def _iterate_slices(self) -> Iterable[Series]:
740741
raise AbstractMethodError(self)
741742

742743
def transform(self, func, *args, **kwargs):
@@ -832,7 +833,8 @@ def _transform_should_cast(self, func_nm: str) -> bool:
832833

833834
def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs):
834835
output = collections.OrderedDict() # type: dict
835-
for name, obj in self._iterate_slices():
836+
for obj in self._iterate_slices():
837+
name = obj.name
836838
is_numeric = is_numeric_dtype(obj.dtype)
837839
if numeric_only and not is_numeric:
838840
continue
@@ -864,7 +866,8 @@ def _cython_agg_general(
864866
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
865867
):
866868
output = {}
867-
for name, obj in self._iterate_slices():
869+
for obj in self._iterate_slices():
870+
name = obj.name
868871
is_numeric = is_numeric_dtype(obj.dtype)
869872
if numeric_only and not is_numeric:
870873
continue
@@ -883,7 +886,8 @@ def _python_agg_general(self, func, *args, **kwargs):
883886

884887
# iterate through "columns" ex exclusions to populate output dict
885888
output = {}
886-
for name, obj in self._iterate_slices():
889+
for obj in self._iterate_slices():
890+
name = obj.name
887891
if self.grouper.ngroups == 0:
888892
# agg_series below assumes ngroups > 0
889893
continue
@@ -2234,7 +2238,8 @@ def _get_cythonized_result(
22342238
output = collections.OrderedDict() # type: dict
22352239
base_func = getattr(libgroupby, how)
22362240

2237-
for name, obj in self._iterate_slices():
2241+
for obj in self._iterate_slices():
2242+
name = obj.name
22382243
values = obj._data._values
22392244

22402245
if aggregate:

0 commit comments

Comments
 (0)