 )
 from pandas.core.dtypes.missing import _maybe_fill, isna
 
+from pandas._typing import FrameOrSeries
 import pandas.core.algorithms as algorithms
 from pandas.core.base import SelectionMixin
 import pandas.core.common as com
@@ -89,12 +90,16 @@ def __init__(
 
         self._filter_empty_groups = self.compressed = len(groupings) != 1
         self.axis = axis
-        self.groupings = groupings  # type: Sequence[grouper.Grouping]
+        self._groupings = list(groupings)  # type: List[grouper.Grouping]
         self.sort = sort
         self.group_keys = group_keys
         self.mutated = mutated
         self.indexer = indexer
 
+    @property
+    def groupings(self) -> List["grouper.Grouping"]:
+        return self._groupings
+
     @property
     def shape(self):
         return tuple(ping.ngroups for ping in self.groupings)
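
For illustration, here is a minimal sketch of the read-only-property pattern this hunk introduces: the groupings are copied into a private list and exposed through a `groupings` property. The class and helper names below are simplified stand-ins, not the actual pandas classes.

```python
from typing import List


class Grouping:
    """Simplified stand-in for pandas.core.groupby.grouper.Grouping."""

    def __init__(self, ngroups: int):
        self.ngroups = ngroups


class GrouperSketch:
    def __init__(self, groupings):
        # Copy into a private list; external code reads it via the property.
        self._groupings = list(groupings)  # type: List[Grouping]

    @property
    def groupings(self) -> List[Grouping]:
        # Read-only accessor: `obj.groupings = ...` now raises AttributeError.
        return self._groupings

    @property
    def shape(self):
        return tuple(ping.ngroups for ping in self.groupings)


sketch = GrouperSketch([Grouping(2), Grouping(3)])
print(sketch.shape)  # (2, 3)
```
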
@@ -106,7 +111,7 @@ def __iter__(self):
     def nkeys(self) -> int:
         return len(self.groupings)
 
-    def get_iterator(self, data, axis=0):
+    def get_iterator(self, data: FrameOrSeries, axis: int = 0):
         """
         Groupby iterator
 
@@ -120,7 +125,7 @@ def get_iterator(self, data, axis=0):
         for key, (i, group) in zip(keys, splitter):
             yield key, group
 
-    def _get_splitter(self, data, axis=0):
+    def _get_splitter(self, data: FrameOrSeries, axis: int = 0) -> "DataSplitter":
         comp_ids, _, ngroups = self.group_info
         return get_splitter(data, comp_ids, ngroups, axis=axis)
 
@@ -142,13 +147,13 @@ def _get_group_keys(self):
         # provide "flattened" iterator for multi-group setting
         return get_flattened_iterator(comp_ids, ngroups, self.levels, self.codes)
 
-    def apply(self, f, data, axis: int = 0):
+    def apply(self, f, data: FrameOrSeries, axis: int = 0):
         mutated = self.mutated
         splitter = self._get_splitter(data, axis=axis)
         group_keys = self._get_group_keys()
         result_values = None
 
-        sdata = splitter._get_sorted_data()
+        sdata = splitter._get_sorted_data()  # type: FrameOrSeries
         if sdata.ndim == 2 and np.any(sdata.dtypes.apply(is_extension_array_dtype)):
             # calling splitter.fast_apply will raise TypeError via apply_frame_axis0
             # if we pass EA instead of ndarray
@@ -157,7 +162,7 @@ def apply(self, f, data, axis: int = 0):
 
         elif (
             com.get_callable_name(f) not in base.plotting_methods
-            and hasattr(splitter, "fast_apply")
+            and isinstance(splitter, FrameSplitter)
             and axis == 0
             # with MultiIndex, apply_frame_axis0 would raise InvalidApply
             # TODO: can we make this check prettier?
@@ -229,8 +234,7 @@ def names(self):
 
     def size(self) -> Series:
         """
-        Compute group sizes
-
+        Compute group sizes.
         """
         ids, _, ngroup = self.group_info
         ids = ensure_platform_int(ids)
@@ -292,7 +296,7 @@ def reconstructed_codes(self) -> List[np.ndarray]:
         return decons_obs_group_ids(comp_ids, obs_ids, self.shape, codes, xnull=True)
 
     @cache_readonly
-    def result_index(self):
+    def result_index(self) -> Index:
         if not self.compressed and len(self.groupings) == 1:
             return self.groupings[0].result_index.rename(self.names[0])
 
@@ -629,7 +633,7 @@ def agg_series(self, obj: Series, func):
                 raise
         return self._aggregate_series_pure_python(obj, func)
 
-    def _aggregate_series_fast(self, obj, func):
+    def _aggregate_series_fast(self, obj: Series, func):
         # At this point we have already checked that
         #  - obj.index is not a MultiIndex
         #  - obj is backed by an ndarray, not ExtensionArray
@@ -648,7 +652,7 @@ def _aggregate_series_fast(self, obj, func):
         result, counts = grouper.get_result()
         return result, counts
 
-    def _aggregate_series_pure_python(self, obj, func):
+    def _aggregate_series_pure_python(self, obj: Series, func):
 
         group_index, _, ngroups = self.group_info
 
@@ -705,7 +709,12 @@ class BinGrouper(BaseGrouper):
     """
 
     def __init__(
-        self, bins, binlabels, filter_empty=False, mutated=False, indexer=None
+        self,
+        bins,
+        binlabels,
+        filter_empty: bool = False,
+        mutated: bool = False,
+        indexer=None,
     ):
         self.bins = ensure_int64(bins)
         self.binlabels = ensure_index(binlabels)
@@ -739,7 +748,7 @@ def _get_grouper(self):
         """
         return self
 
-    def get_iterator(self, data: NDFrame, axis: int = 0):
+    def get_iterator(self, data: FrameOrSeries, axis: int = 0):
         """
         Groupby iterator
 
@@ -811,11 +820,9 @@ def names(self):
         return [self.binlabels.name]
 
     @property
-    def groupings(self):
-        from pandas.core.groupby.grouper import Grouping
-
+    def groupings(self) -> "List[grouper.Grouping]":
         return [
-            Grouping(lvl, lvl, in_axis=False, level=None, name=name)
+            grouper.Grouping(lvl, lvl, in_axis=False, level=None, name=name)
             for lvl, name in zip(self.levels, self.names)
         ]
 
@@ -856,7 +863,7 @@ def _is_indexed_like(obj, axes) -> bool:
 
 
 class DataSplitter:
-    def __init__(self, data, labels, ngroups, axis: int = 0):
+    def __init__(self, data: FrameOrSeries, labels, ngroups: int, axis: int = 0):
         self.data = data
         self.labels = ensure_int64(labels)
         self.ngroups = ngroups
@@ -887,15 +894,15 @@ def __iter__(self):
         for i, (start, end) in enumerate(zip(starts, ends)):
             yield i, self._chop(sdata, slice(start, end))
 
-    def _get_sorted_data(self):
+    def _get_sorted_data(self) -> FrameOrSeries:
         return self.data.take(self.sort_idx, axis=self.axis)
 
-    def _chop(self, sdata, slice_obj: slice):
+    def _chop(self, sdata, slice_obj: slice) -> NDFrame:
         raise AbstractMethodError(self)
 
 
 class SeriesSplitter(DataSplitter):
-    def _chop(self, sdata, slice_obj: slice):
+    def _chop(self, sdata: Series, slice_obj: slice) -> Series:
         return sdata._get_values(slice_obj)
 
 
@@ -907,14 +914,14 @@ def fast_apply(self, f, names):
         sdata = self._get_sorted_data()
         return libreduction.apply_frame_axis0(sdata, f, names, starts, ends)
 
-    def _chop(self, sdata, slice_obj: slice):
+    def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
         if self.axis == 0:
             return sdata.iloc[slice_obj]
         else:
             return sdata._slice(slice_obj, axis=1)
 
 
-def get_splitter(data: NDFrame, *args, **kwargs):
+def get_splitter(data: FrameOrSeries, *args, **kwargs) -> DataSplitter:
     if isinstance(data, Series):
         klass = SeriesSplitter  # type: Type[DataSplitter]
     else:
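
For context, here is a minimal sketch of how a `FrameOrSeries`-style annotation behaves when used on both a parameter and a return value, as the diff above does. The `TypeVar` definition and the `sorted_view` helper below are simplified stand-ins for illustration and may differ from the actual `pandas._typing.FrameOrSeries` alias.

```python
from typing import TypeVar

import pandas as pd

# Simplified stand-in: a TypeVar constrained to DataFrame or Series, so that
# annotating both the argument and the return value preserves the input type
# (DataFrame in -> DataFrame out, Series in -> Series out) for type checkers.
FrameOrSeries = TypeVar("FrameOrSeries", pd.DataFrame, pd.Series)


def sorted_view(data: FrameOrSeries, indexer, axis: int = 0) -> FrameOrSeries:
    # Mirrors the shape of DataSplitter._get_sorted_data: take rows (or
    # columns) in the given order and return the same kind of object.
    return data.take(indexer, axis=axis)


df = pd.DataFrame({"a": [3, 1, 2]})
print(sorted_view(df, [1, 2, 0]))
print(sorted_view(df["a"], [1, 2, 0]))
```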