@@ -734,7 +734,14 @@ def _get_splitter(self, data: NDFrame, axis: AxisInt = 0) -> DataSplitter:
734
734
Generator yielding subsetted objects
735
735
"""
736
736
ids , _ , ngroups = self .group_info
737
- return _get_splitter (data , ids , ngroups , axis = axis )
737
+ return _get_splitter (
738
+ data ,
739
+ ids ,
740
+ ngroups ,
741
+ sorted_ids = self ._sorted_ids ,
742
+ sort_idx = self ._sort_idx ,
743
+ axis = axis ,
744
+ )
738
745
739
746
@final
740
747
@cache_readonly
@@ -747,45 +754,6 @@ def group_keys_seq(self):
747
754
# provide "flattened" iterator for multi-group setting
748
755
return get_flattened_list (ids , ngroups , self .levels , self .codes )
749
756
750
- @final
751
- def apply_groupwise (
752
- self , f : Callable , data : DataFrame | Series , axis : AxisInt = 0
753
- ) -> tuple [list , bool ]:
754
- mutated = False
755
- splitter = self ._get_splitter (data , axis = axis )
756
- group_keys = self .group_keys_seq
757
- result_values = []
758
-
759
- # This calls DataSplitter.__iter__
760
- zipped = zip (group_keys , splitter )
761
-
762
- for key , group in zipped :
763
- # Pinning name is needed for
764
- # test_group_apply_once_per_group,
765
- # test_inconsistent_return_type, test_set_group_name,
766
- # test_group_name_available_in_inference_pass,
767
- # test_groupby_multi_timezone
768
- object .__setattr__ (group , "name" , key )
769
-
770
- # group might be modified
771
- group_axes = group .axes
772
- res = f (group )
773
- if not mutated and not _is_indexed_like (res , group_axes , axis ):
774
- mutated = True
775
- result_values .append (res )
776
- # getattr pattern for __name__ is needed for functools.partial objects
777
- if len (group_keys ) == 0 and getattr (f , "__name__" , None ) in [
778
- "skew" ,
779
- "sum" ,
780
- "prod" ,
781
- ]:
782
- # If group_keys is empty, then no function calls have been made,
783
- # so we will not have raised even if this is an invalid dtype.
784
- # So do one dummy call here to raise appropriate TypeError.
785
- f (data .iloc [:0 ])
786
-
787
- return result_values , mutated
788
-
789
757
@cache_readonly
790
758
def indices (self ) -> dict [Hashable , npt .NDArray [np .intp ]]:
791
759
"""dict {group name -> group indices}"""
@@ -1029,6 +997,61 @@ def _aggregate_series_pure_python(
1029
997
1030
998
return result
1031
999
1000
@final
def apply_groupwise(
    self, f: Callable, data: DataFrame | Series, axis: AxisInt = 0
) -> tuple[list, bool]:
    """
    Call ``f`` once per group of ``data`` and collect the results.

    Parameters
    ----------
    f : Callable
        Function invoked with each group object produced by the splitter.
    data : DataFrame or Series
        Object that is split into groups along ``axis``.
    axis : AxisInt, default 0
        Axis forwarded to the splitter and to the indexed-like check.

    Returns
    -------
    tuple[list, bool]
        The per-group results in iteration order, and a flag that is True
        as soon as one result is not indexed like the axes its group had
        before ``f`` was called.
    """
    mutated = False
    splitter = self._get_splitter(data, axis=axis)
    group_keys = self.group_keys_seq
    collected = []

    # Iterating the splitter goes through DataSplitter.__iter__
    for label, chunk in zip(group_keys, splitter):
        # Pinning name is needed for
        #  test_group_apply_once_per_group,
        #  test_inconsistent_return_type, test_set_group_name,
        #  test_group_name_available_in_inference_pass,
        #  test_groupby_multi_timezone
        object.__setattr__(chunk, "name", label)

        # Grab the axes up front: f is allowed to modify the group.
        axes_before = chunk.axes
        outcome = f(chunk)
        # Once flagged we stop checking, so the comparison is skipped.
        mutated = mutated or not _is_indexed_like(outcome, axes_before, axis)
        collected.append(outcome)

    if len(group_keys) == 0:
        # getattr pattern for __name__ is needed for functools.partial objects
        func_name = getattr(f, "__name__", None)
        if func_name in ("skew", "sum", "prod"):
            # With no groups, f has never been called, so an invalid dtype
            # would go unnoticed. One dummy call on an empty slice raises
            # the appropriate TypeError.
            f(data.iloc[:0])

    return collected, mutated
1038
+
1039
+ # ------------------------------------------------------------
1040
+ # Methods for sorting subsets of our GroupBy's object
1041
+
1042
@final
@cache_readonly
def _sort_idx(self) -> npt.NDArray[np.intp]:
    """
    Counting sort indexer for this grouper's group ids.

    Built by ``get_group_index_sorter`` from the ids and the number of
    groups reported by ``self.group_info``.
    """
    group_ids, _, n_groups = self.group_info
    return get_group_index_sorter(group_ids, n_groups)
1048
+
1049
@final
@cache_readonly
def _sorted_ids(self) -> npt.NDArray[np.intp]:
    """
    Group ids from ``self.group_info`` reordered by ``self._sort_idx``.
    """
    group_ids, _, _ = self.group_info
    sorter = self._sort_idx
    return group_ids.take(sorter)
1054
+
1032
1055
1033
1056
class BinGrouper (BaseGrouper ):
1034
1057
"""
@@ -1211,25 +1234,21 @@ def __init__(
1211
1234
data : NDFrameT ,
1212
1235
labels : npt .NDArray [np .intp ],
1213
1236
ngroups : int ,
1237
+ * ,
1238
+ sort_idx : npt .NDArray [np .intp ],
1239
+ sorted_ids : npt .NDArray [np .intp ],
1214
1240
axis : AxisInt = 0 ,
1215
1241
) -> None :
1216
1242
self .data = data
1217
1243
self .labels = ensure_platform_int (labels ) # _should_ already be np.intp
1218
1244
self .ngroups = ngroups
1219
1245
1246
+ self ._slabels = sorted_ids
1247
+ self ._sort_idx = sort_idx
1248
+
1220
1249
self .axis = axis
1221
1250
assert isinstance (axis , int ), axis
1222
1251
1223
- @cache_readonly
1224
- def _slabels (self ) -> npt .NDArray [np .intp ]:
1225
- # Sorted labels
1226
- return self .labels .take (self ._sort_idx )
1227
-
1228
- @cache_readonly
1229
- def _sort_idx (self ) -> npt .NDArray [np .intp ]:
1230
- # Counting sort indexer
1231
- return get_group_index_sorter (self .labels , self .ngroups )
1232
-
1233
1252
def __iter__ (self ) -> Iterator :
1234
1253
sdata = self ._sorted_data
1235
1254
@@ -1272,12 +1291,20 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
1272
1291
1273
1292
1274
1293
def _get_splitter(
    data: NDFrame,
    labels: npt.NDArray[np.intp],
    ngroups: int,
    *,
    sort_idx: npt.NDArray[np.intp],
    sorted_ids: npt.NDArray[np.intp],
    axis: AxisInt = 0,
) -> DataSplitter:
    """
    Build the splitter that matches the type of ``data``.

    Parameters
    ----------
    data : NDFrame
        Object to split; a Series gets a SeriesSplitter, anything else
        (i.e. a DataFrame) gets a FrameSplitter.
    labels : np.ndarray[np.intp]
        Passed through to the splitter unchanged.
    ngroups : int
        Passed through to the splitter unchanged.
    sort_idx : np.ndarray[np.intp]
        Keyword-only; passed through to the splitter unchanged.
    sorted_ids : np.ndarray[np.intp]
        Keyword-only; passed through to the splitter unchanged.
    axis : AxisInt, default 0
        Passed through to the splitter unchanged.

    Returns
    -------
    DataSplitter
    """
    splitter_cls: type[DataSplitter] = (
        SeriesSplitter if isinstance(data, Series) else FrameSplitter
    )
    return splitter_cls(
        data,
        labels,
        ngroups,
        sort_idx=sort_idx,
        sorted_ids=sorted_ids,
        axis=axis,
    )
0 commit comments