Skip to content

Commit 4b3027f

Browse files
jbrockmendelWillAyd
authored andcommitted
CLN: type annotations in groupby.grouper, groupby.ops (#29456)
* Annotate groupby.ops * annotations, needs debugging * whitespace * types * circular import * fix msot mypy complaints * fix mypy groupings * merge cleanup
1 parent 57e1b34 commit 4b3027f

File tree

2 files changed

+42
-34
lines changed

2 files changed

+42
-34
lines changed

pandas/core/groupby/grouper.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def __init__(self, key=None, level=None, freq=None, axis=0, sort=False):
119119
def ax(self):
120120
return self.grouper
121121

122-
def _get_grouper(self, obj, validate=True):
122+
def _get_grouper(self, obj, validate: bool = True):
123123
"""
124124
Parameters
125125
----------
@@ -143,17 +143,18 @@ def _get_grouper(self, obj, validate=True):
143143
)
144144
return self.binner, self.grouper, self.obj
145145

146-
def _set_grouper(self, obj, sort=False):
146+
def _set_grouper(self, obj: FrameOrSeries, sort: bool = False):
147147
"""
148148
given an object and the specifications, setup the internal grouper
149149
for this particular specification
150150
151151
Parameters
152152
----------
153-
obj : the subject object
153+
obj : Series or DataFrame
154154
sort : bool, default False
155155
whether the resulting grouper should be sorted
156156
"""
157+
assert obj is not None
157158

158159
if self.key is not None and self.level is not None:
159160
raise ValueError("The Grouper cannot specify both a key and a level!")
@@ -211,13 +212,13 @@ def groups(self):
211212

212213
def __repr__(self) -> str:
213214
attrs_list = (
214-
"{}={!r}".format(attr_name, getattr(self, attr_name))
215+
"{name}={val!r}".format(name=attr_name, val=getattr(self, attr_name))
215216
for attr_name in self._attributes
216217
if getattr(self, attr_name) is not None
217218
)
218219
attrs = ", ".join(attrs_list)
219220
cls_name = self.__class__.__name__
220-
return "{}({})".format(cls_name, attrs)
221+
return "{cls}({attrs})".format(cls=cls_name, attrs=attrs)
221222

222223

223224
class Grouping:
@@ -372,7 +373,7 @@ def __init__(
372373
self.grouper = self.grouper.astype("timedelta64[ns]")
373374

374375
def __repr__(self) -> str:
375-
return "Grouping({0})".format(self.name)
376+
return "Grouping({name})".format(name=self.name)
376377

377378
def __iter__(self):
378379
return iter(self.indices)
@@ -433,10 +434,10 @@ def get_grouper(
433434
key=None,
434435
axis: int = 0,
435436
level=None,
436-
sort=True,
437-
observed=False,
438-
mutated=False,
439-
validate=True,
437+
sort: bool = True,
438+
observed: bool = False,
439+
mutated: bool = False,
440+
validate: bool = True,
440441
) -> Tuple[BaseGrouper, List[Hashable], FrameOrSeries]:
441442
"""
442443
Create and return a BaseGrouper, which is an internal
@@ -670,7 +671,7 @@ def is_in_obj(gpr) -> bool:
670671
return grouper, exclusions, obj
671672

672673

673-
def _is_label_like(val):
674+
def _is_label_like(val) -> bool:
674675
return isinstance(val, (str, tuple)) or (val is not None and is_scalar(val))
675676

676677

pandas/core/groupby/ops.py

+30-23
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
)
3737
from pandas.core.dtypes.missing import _maybe_fill, isna
3838

39+
from pandas._typing import FrameOrSeries
3940
import pandas.core.algorithms as algorithms
4041
from pandas.core.base import SelectionMixin
4142
import pandas.core.common as com
@@ -89,12 +90,16 @@ def __init__(
8990

9091
self._filter_empty_groups = self.compressed = len(groupings) != 1
9192
self.axis = axis
92-
self.groupings = groupings # type: Sequence[grouper.Grouping]
93+
self._groupings = list(groupings) # type: List[grouper.Grouping]
9394
self.sort = sort
9495
self.group_keys = group_keys
9596
self.mutated = mutated
9697
self.indexer = indexer
9798

99+
@property
100+
def groupings(self) -> List["grouper.Grouping"]:
101+
return self._groupings
102+
98103
@property
99104
def shape(self):
100105
return tuple(ping.ngroups for ping in self.groupings)
@@ -106,7 +111,7 @@ def __iter__(self):
106111
def nkeys(self) -> int:
107112
return len(self.groupings)
108113

109-
def get_iterator(self, data, axis=0):
114+
def get_iterator(self, data: FrameOrSeries, axis: int = 0):
110115
"""
111116
Groupby iterator
112117
@@ -120,7 +125,7 @@ def get_iterator(self, data, axis=0):
120125
for key, (i, group) in zip(keys, splitter):
121126
yield key, group
122127

123-
def _get_splitter(self, data, axis=0):
128+
def _get_splitter(self, data: FrameOrSeries, axis: int = 0) -> "DataSplitter":
124129
comp_ids, _, ngroups = self.group_info
125130
return get_splitter(data, comp_ids, ngroups, axis=axis)
126131

@@ -142,13 +147,13 @@ def _get_group_keys(self):
142147
# provide "flattened" iterator for multi-group setting
143148
return get_flattened_iterator(comp_ids, ngroups, self.levels, self.codes)
144149

145-
def apply(self, f, data, axis: int = 0):
150+
def apply(self, f, data: FrameOrSeries, axis: int = 0):
146151
mutated = self.mutated
147152
splitter = self._get_splitter(data, axis=axis)
148153
group_keys = self._get_group_keys()
149154
result_values = None
150155

151-
sdata = splitter._get_sorted_data()
156+
sdata = splitter._get_sorted_data() # type: FrameOrSeries
152157
if sdata.ndim == 2 and np.any(sdata.dtypes.apply(is_extension_array_dtype)):
153158
# calling splitter.fast_apply will raise TypeError via apply_frame_axis0
154159
# if we pass EA instead of ndarray
@@ -157,7 +162,7 @@ def apply(self, f, data, axis: int = 0):
157162

158163
elif (
159164
com.get_callable_name(f) not in base.plotting_methods
160-
and hasattr(splitter, "fast_apply")
165+
and isinstance(splitter, FrameSplitter)
161166
and axis == 0
162167
# with MultiIndex, apply_frame_axis0 would raise InvalidApply
163168
# TODO: can we make this check prettier?
@@ -229,8 +234,7 @@ def names(self):
229234

230235
def size(self) -> Series:
231236
"""
232-
Compute group sizes
233-
237+
Compute group sizes.
234238
"""
235239
ids, _, ngroup = self.group_info
236240
ids = ensure_platform_int(ids)
@@ -292,7 +296,7 @@ def reconstructed_codes(self) -> List[np.ndarray]:
292296
return decons_obs_group_ids(comp_ids, obs_ids, self.shape, codes, xnull=True)
293297

294298
@cache_readonly
295-
def result_index(self):
299+
def result_index(self) -> Index:
296300
if not self.compressed and len(self.groupings) == 1:
297301
return self.groupings[0].result_index.rename(self.names[0])
298302

@@ -629,7 +633,7 @@ def agg_series(self, obj: Series, func):
629633
raise
630634
return self._aggregate_series_pure_python(obj, func)
631635

632-
def _aggregate_series_fast(self, obj, func):
636+
def _aggregate_series_fast(self, obj: Series, func):
633637
# At this point we have already checked that
634638
# - obj.index is not a MultiIndex
635639
# - obj is backed by an ndarray, not ExtensionArray
@@ -648,7 +652,7 @@ def _aggregate_series_fast(self, obj, func):
648652
result, counts = grouper.get_result()
649653
return result, counts
650654

651-
def _aggregate_series_pure_python(self, obj, func):
655+
def _aggregate_series_pure_python(self, obj: Series, func):
652656

653657
group_index, _, ngroups = self.group_info
654658

@@ -705,7 +709,12 @@ class BinGrouper(BaseGrouper):
705709
"""
706710

707711
def __init__(
708-
self, bins, binlabels, filter_empty=False, mutated=False, indexer=None
712+
self,
713+
bins,
714+
binlabels,
715+
filter_empty: bool = False,
716+
mutated: bool = False,
717+
indexer=None,
709718
):
710719
self.bins = ensure_int64(bins)
711720
self.binlabels = ensure_index(binlabels)
@@ -739,7 +748,7 @@ def _get_grouper(self):
739748
"""
740749
return self
741750

742-
def get_iterator(self, data: NDFrame, axis: int = 0):
751+
def get_iterator(self, data: FrameOrSeries, axis: int = 0):
743752
"""
744753
Groupby iterator
745754
@@ -811,11 +820,9 @@ def names(self):
811820
return [self.binlabels.name]
812821

813822
@property
814-
def groupings(self):
815-
from pandas.core.groupby.grouper import Grouping
816-
823+
def groupings(self) -> "List[grouper.Grouping]":
817824
return [
818-
Grouping(lvl, lvl, in_axis=False, level=None, name=name)
825+
grouper.Grouping(lvl, lvl, in_axis=False, level=None, name=name)
819826
for lvl, name in zip(self.levels, self.names)
820827
]
821828

@@ -856,7 +863,7 @@ def _is_indexed_like(obj, axes) -> bool:
856863

857864

858865
class DataSplitter:
859-
def __init__(self, data, labels, ngroups, axis: int = 0):
866+
def __init__(self, data: FrameOrSeries, labels, ngroups: int, axis: int = 0):
860867
self.data = data
861868
self.labels = ensure_int64(labels)
862869
self.ngroups = ngroups
@@ -887,15 +894,15 @@ def __iter__(self):
887894
for i, (start, end) in enumerate(zip(starts, ends)):
888895
yield i, self._chop(sdata, slice(start, end))
889896

890-
def _get_sorted_data(self):
897+
def _get_sorted_data(self) -> FrameOrSeries:
891898
return self.data.take(self.sort_idx, axis=self.axis)
892899

893-
def _chop(self, sdata, slice_obj: slice):
900+
def _chop(self, sdata, slice_obj: slice) -> NDFrame:
894901
raise AbstractMethodError(self)
895902

896903

897904
class SeriesSplitter(DataSplitter):
898-
def _chop(self, sdata, slice_obj: slice):
905+
def _chop(self, sdata: Series, slice_obj: slice) -> Series:
899906
return sdata._get_values(slice_obj)
900907

901908

@@ -907,14 +914,14 @@ def fast_apply(self, f, names):
907914
sdata = self._get_sorted_data()
908915
return libreduction.apply_frame_axis0(sdata, f, names, starts, ends)
909916

910-
def _chop(self, sdata, slice_obj: slice):
917+
def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
911918
if self.axis == 0:
912919
return sdata.iloc[slice_obj]
913920
else:
914921
return sdata._slice(slice_obj, axis=1)
915922

916923

917-
def get_splitter(data: NDFrame, *args, **kwargs):
924+
def get_splitter(data: FrameOrSeries, *args, **kwargs) -> DataSplitter:
918925
if isinstance(data, Series):
919926
klass = SeriesSplitter # type: Type[DataSplitter]
920927
else:

0 commit comments

Comments
 (0)