14
14
from pandas ._libs import NaT , iNaT , lib
15
15
import pandas ._libs .groupby as libgroupby
16
16
import pandas ._libs .reduction as libreduction
17
- from pandas ._typing import FrameOrSeries
17
+ from pandas ._typing import F , FrameOrSeries , Label
18
18
from pandas .errors import AbstractMethodError
19
19
from pandas .util ._decorators import cache_readonly
20
20
@@ -110,7 +110,7 @@ def groupings(self) -> List["grouper.Grouping"]:
110
110
return self ._groupings
111
111
112
112
@property
113
- def shape (self ):
113
+ def shape (self ) -> Tuple [ int , ...] :
114
114
return tuple (ping .ngroups for ping in self .groupings )
115
115
116
116
def __iter__ (self ):
@@ -156,7 +156,7 @@ def _get_group_keys(self):
156
156
# provide "flattened" iterator for multi-group setting
157
157
return get_flattened_iterator (comp_ids , ngroups , self .levels , self .codes )
158
158
159
- def apply (self , f , data : FrameOrSeries , axis : int = 0 ):
159
+ def apply (self , f : F , data : FrameOrSeries , axis : int = 0 ):
160
160
mutated = self .mutated
161
161
splitter = self ._get_splitter (data , axis = axis )
162
162
group_keys = self ._get_group_keys ()
@@ -237,7 +237,7 @@ def levels(self) -> List[Index]:
237
237
return [ping .group_index for ping in self .groupings ]
238
238
239
239
@property
240
- def names (self ):
240
+ def names (self ) -> List [ Label ] :
241
241
return [ping .name for ping in self .groupings ]
242
242
243
243
def size (self ) -> Series :
@@ -315,7 +315,7 @@ def result_index(self) -> Index:
315
315
)
316
316
return result
317
317
318
- def get_group_levels (self ):
318
+ def get_group_levels (self ) -> List [ Index ] :
319
319
if not self .compressed and len (self .groupings ) == 1 :
320
320
return [self .groupings [0 ].result_index ]
321
321
@@ -364,7 +364,9 @@ def _is_builtin_func(self, arg):
364
364
"""
365
365
return SelectionMixin ._builtin_table .get (arg , arg )
366
366
367
- def _get_cython_function (self , kind : str , how : str , values , is_numeric : bool ):
367
+ def _get_cython_function (
368
+ self , kind : str , how : str , values : np .ndarray , is_numeric : bool
369
+ ):
368
370
369
371
dtype_str = values .dtype .name
370
372
ftype = self ._cython_functions [kind ][how ]
@@ -433,7 +435,7 @@ def _get_cython_func_and_vals(
433
435
return func , values
434
436
435
437
def _cython_operation (
436
- self , kind : str , values , how : str , axis , min_count : int = - 1 , ** kwargs
438
+ self , kind : str , values , how : str , axis : int , min_count : int = - 1 , ** kwargs
437
439
) -> Tuple [np .ndarray , Optional [List [str ]]]:
438
440
"""
439
441
Returns the values of a cython operation as a Tuple of [data, names].
@@ -617,7 +619,13 @@ def _transform(
617
619
return result
618
620
619
621
def agg_series (
620
- self , obj : Series , func , * args , engine = "cython" , engine_kwargs = None , ** kwargs
622
+ self ,
623
+ obj : Series ,
624
+ func : F ,
625
+ * args ,
626
+ engine : str = "cython" ,
627
+ engine_kwargs = None ,
628
+ ** kwargs ,
621
629
):
622
630
# Caller is responsible for checking ngroups != 0
623
631
assert self .ngroups != 0
@@ -651,7 +659,7 @@ def agg_series(
651
659
raise
652
660
return self ._aggregate_series_pure_python (obj , func )
653
661
654
- def _aggregate_series_fast (self , obj : Series , func ):
662
+ def _aggregate_series_fast (self , obj : Series , func : F ):
655
663
# At this point we have already checked that
656
664
# - obj.index is not a MultiIndex
657
665
# - obj is backed by an ndarray, not ExtensionArray
@@ -671,7 +679,13 @@ def _aggregate_series_fast(self, obj: Series, func):
671
679
return result , counts
672
680
673
681
def _aggregate_series_pure_python (
674
- self , obj : Series , func , * args , engine = "cython" , engine_kwargs = None , ** kwargs
682
+ self ,
683
+ obj : Series ,
684
+ func : F ,
685
+ * args ,
686
+ engine : str = "cython" ,
687
+ engine_kwargs = None ,
688
+ ** kwargs ,
675
689
):
676
690
677
691
if engine == "numba" :
@@ -860,11 +874,11 @@ def result_index(self):
860
874
return self .binlabels
861
875
862
876
@property
863
- def levels (self ):
877
+ def levels (self ) -> List [ Index ] :
864
878
return [self .binlabels ]
865
879
866
880
@property
867
- def names (self ):
881
+ def names (self ) -> List [ Label ] :
868
882
return [self .binlabels .name ]
869
883
870
884
@property
@@ -875,7 +889,13 @@ def groupings(self) -> "List[grouper.Grouping]":
875
889
]
876
890
877
891
def agg_series (
878
- self , obj : Series , func , * args , engine = "cython" , engine_kwargs = None , ** kwargs
892
+ self ,
893
+ obj : Series ,
894
+ func : F ,
895
+ * args ,
896
+ engine : str = "cython" ,
897
+ engine_kwargs = None ,
898
+ ** kwargs ,
879
899
):
880
900
# Caller is responsible for checking ngroups != 0
881
901
assert self .ngroups != 0
@@ -950,7 +970,7 @@ def _chop(self, sdata: Series, slice_obj: slice) -> Series:
950
970
951
971
952
972
class FrameSplitter (DataSplitter ):
953
- def fast_apply (self , f , sdata : FrameOrSeries , names ):
973
+ def fast_apply (self , f : F , sdata : FrameOrSeries , names ):
954
974
# must return keys::list, values::list, mutated::bool
955
975
starts , ends = lib .generate_slices (self .slabels , self .ngroups )
956
976
return libreduction .apply_frame_axis0 (sdata , f , names , starts , ends )
0 commit comments