Skip to content

Commit f683473

Browse files
dsaxtonWillAyd
andauthored
TYP: Annotate groupby/ops.py (#32921)
* TYP: Annotate groupby/ops.py * Blacken * Update pandas/core/groupby/ops.py Co-Authored-By: William Ayd <[email protected]> * Use ellipsis * List -> List[Index] * Specify Callable types * More Callable subscripts * Update * No ArrayLike * Import * Update * Use F * Lint Co-authored-by: William Ayd <[email protected]>
1 parent 22cf0f5 commit f683473

File tree

2 files changed

+35
-15
lines changed

2 files changed

+35
-15
lines changed

pandas/core/groupby/grouper.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ class Grouping:
257257
index : Index
258258
grouper :
259259
obj Union[DataFrame, Series]:
260-
name :
260+
name : Label
261261
level :
262262
observed : bool, default False
263263
If we are a Categorical, use the observed values

pandas/core/groupby/ops.py

+34-14
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pandas._libs import NaT, iNaT, lib
1515
import pandas._libs.groupby as libgroupby
1616
import pandas._libs.reduction as libreduction
17-
from pandas._typing import FrameOrSeries
17+
from pandas._typing import F, FrameOrSeries, Label
1818
from pandas.errors import AbstractMethodError
1919
from pandas.util._decorators import cache_readonly
2020

@@ -110,7 +110,7 @@ def groupings(self) -> List["grouper.Grouping"]:
110110
return self._groupings
111111

112112
@property
113-
def shape(self):
113+
def shape(self) -> Tuple[int, ...]:
114114
return tuple(ping.ngroups for ping in self.groupings)
115115

116116
def __iter__(self):
@@ -156,7 +156,7 @@ def _get_group_keys(self):
156156
# provide "flattened" iterator for multi-group setting
157157
return get_flattened_iterator(comp_ids, ngroups, self.levels, self.codes)
158158

159-
def apply(self, f, data: FrameOrSeries, axis: int = 0):
159+
def apply(self, f: F, data: FrameOrSeries, axis: int = 0):
160160
mutated = self.mutated
161161
splitter = self._get_splitter(data, axis=axis)
162162
group_keys = self._get_group_keys()
@@ -237,7 +237,7 @@ def levels(self) -> List[Index]:
237237
return [ping.group_index for ping in self.groupings]
238238

239239
@property
240-
def names(self):
240+
def names(self) -> List[Label]:
241241
return [ping.name for ping in self.groupings]
242242

243243
def size(self) -> Series:
@@ -315,7 +315,7 @@ def result_index(self) -> Index:
315315
)
316316
return result
317317

318-
def get_group_levels(self):
318+
def get_group_levels(self) -> List[Index]:
319319
if not self.compressed and len(self.groupings) == 1:
320320
return [self.groupings[0].result_index]
321321

@@ -364,7 +364,9 @@ def _is_builtin_func(self, arg):
364364
"""
365365
return SelectionMixin._builtin_table.get(arg, arg)
366366

367-
def _get_cython_function(self, kind: str, how: str, values, is_numeric: bool):
367+
def _get_cython_function(
368+
self, kind: str, how: str, values: np.ndarray, is_numeric: bool
369+
):
368370

369371
dtype_str = values.dtype.name
370372
ftype = self._cython_functions[kind][how]
@@ -433,7 +435,7 @@ def _get_cython_func_and_vals(
433435
return func, values
434436

435437
def _cython_operation(
436-
self, kind: str, values, how: str, axis, min_count: int = -1, **kwargs
438+
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
437439
) -> Tuple[np.ndarray, Optional[List[str]]]:
438440
"""
439441
Returns the values of a cython operation as a Tuple of [data, names].
@@ -617,7 +619,13 @@ def _transform(
617619
return result
618620

619621
def agg_series(
620-
self, obj: Series, func, *args, engine="cython", engine_kwargs=None, **kwargs
622+
self,
623+
obj: Series,
624+
func: F,
625+
*args,
626+
engine: str = "cython",
627+
engine_kwargs=None,
628+
**kwargs,
621629
):
622630
# Caller is responsible for checking ngroups != 0
623631
assert self.ngroups != 0
@@ -651,7 +659,7 @@ def agg_series(
651659
raise
652660
return self._aggregate_series_pure_python(obj, func)
653661

654-
def _aggregate_series_fast(self, obj: Series, func):
662+
def _aggregate_series_fast(self, obj: Series, func: F):
655663
# At this point we have already checked that
656664
# - obj.index is not a MultiIndex
657665
# - obj is backed by an ndarray, not ExtensionArray
@@ -671,7 +679,13 @@ def _aggregate_series_fast(self, obj: Series, func):
671679
return result, counts
672680

673681
def _aggregate_series_pure_python(
674-
self, obj: Series, func, *args, engine="cython", engine_kwargs=None, **kwargs
682+
self,
683+
obj: Series,
684+
func: F,
685+
*args,
686+
engine: str = "cython",
687+
engine_kwargs=None,
688+
**kwargs,
675689
):
676690

677691
if engine == "numba":
@@ -860,11 +874,11 @@ def result_index(self):
860874
return self.binlabels
861875

862876
@property
863-
def levels(self):
877+
def levels(self) -> List[Index]:
864878
return [self.binlabels]
865879

866880
@property
867-
def names(self):
881+
def names(self) -> List[Label]:
868882
return [self.binlabels.name]
869883

870884
@property
@@ -875,7 +889,13 @@ def groupings(self) -> "List[grouper.Grouping]":
875889
]
876890

877891
def agg_series(
878-
self, obj: Series, func, *args, engine="cython", engine_kwargs=None, **kwargs
892+
self,
893+
obj: Series,
894+
func: F,
895+
*args,
896+
engine: str = "cython",
897+
engine_kwargs=None,
898+
**kwargs,
879899
):
880900
# Caller is responsible for checking ngroups != 0
881901
assert self.ngroups != 0
@@ -950,7 +970,7 @@ def _chop(self, sdata: Series, slice_obj: slice) -> Series:
950970

951971

952972
class FrameSplitter(DataSplitter):
953-
def fast_apply(self, f, sdata: FrameOrSeries, names):
973+
def fast_apply(self, f: F, sdata: FrameOrSeries, names):
954974
# must return keys::list, values::list, mutated::bool
955975
starts, ends = lib.generate_slices(self.slabels, self.ngroups)
956976
return libreduction.apply_frame_axis0(sdata, f, names, starts, ends)

0 commit comments

Comments
 (0)