@@ -2897,16 +2897,15 @@ def _get_cythonized_result(
2897
2897
2898
2898
ids , _ , ngroups = grouper .group_info
2899
2899
output : dict [base .OutputKey , np .ndarray ] = {}
2900
- base_func = getattr (libgroupby , how )
2901
-
2902
- error_msg = ""
2903
- for idx , obj in enumerate (self ._iterate_slices ()):
2904
- name = obj .name
2905
- values = obj ._values
2906
2900
2907
- if numeric_only and not is_numeric_dtype (values .dtype ):
2908
- continue
2901
+ base_func = getattr (libgroupby , how )
2902
+ base_func = partial (base_func , labels = ids )
2903
+ if needs_ngroups :
2904
+ base_func = partial (base_func , ngroups = ngroups )
2905
+ if min_count is not None :
2906
+ base_func = partial (base_func , min_count = min_count )
2909
2907
2908
+ def blk_func (values : ArrayLike ) -> ArrayLike :
2910
2909
if aggregate :
2911
2910
result_sz = ngroups
2912
2911
else :
@@ -2915,54 +2914,31 @@ def _get_cythonized_result(
2915
2914
result = np .zeros (result_sz , dtype = cython_dtype )
2916
2915
if needs_2d :
2917
2916
result = result .reshape ((- 1 , 1 ))
2918
- func = partial (base_func , result )
2917
+ func = partial (base_func , out = result )
2919
2918
2920
2919
inferences = None
2921
2920
2922
2921
if needs_counts :
2923
2922
counts = np .zeros (self .ngroups , dtype = np .int64 )
2924
- func = partial (func , counts )
2923
+ func = partial (func , counts = counts )
2925
2924
2926
2925
if needs_values :
2927
2926
vals = values
2928
2927
if pre_processing :
2929
- try :
2930
- vals , inferences = pre_processing (vals )
2931
- except TypeError as err :
2932
- error_msg = str (err )
2933
- howstr = how .replace ("group_" , "" )
2934
- warnings .warn (
2935
- "Dropping invalid columns in "
2936
- f"{ type (self ).__name__ } .{ howstr } is deprecated. "
2937
- "In a future version, a TypeError will be raised. "
2938
- f"Before calling .{ howstr } , select only columns which "
2939
- "should be valid for the function." ,
2940
- FutureWarning ,
2941
- stacklevel = 3 ,
2942
- )
2943
- continue
2928
+ vals , inferences = pre_processing (vals )
2929
+
2944
2930
vals = vals .astype (cython_dtype , copy = False )
2945
2931
if needs_2d :
2946
2932
vals = vals .reshape ((- 1 , 1 ))
2947
- func = partial (func , vals )
2948
-
2949
- func = partial (func , ids )
2950
-
2951
- if min_count is not None :
2952
- func = partial (func , min_count )
2933
+ func = partial (func , values = vals )
2953
2934
2954
2935
if needs_mask :
2955
2936
mask = isna (values ).view (np .uint8 )
2956
- func = partial (func , mask )
2957
-
2958
- if needs_ngroups :
2959
- func = partial (func , ngroups )
2937
+ func = partial (func , mask = mask )
2960
2938
2961
2939
if needs_nullable :
2962
2940
is_nullable = isinstance (values , BaseMaskedArray )
2963
2941
func = partial (func , nullable = is_nullable )
2964
- if post_processing :
2965
- post_processing = partial (post_processing , nullable = is_nullable )
2966
2942
2967
2943
func (** kwargs ) # Call func to modify indexer values in place
2968
2944
@@ -2973,9 +2949,38 @@ def _get_cythonized_result(
2973
2949
result = algorithms .take_nd (values , result )
2974
2950
2975
2951
if post_processing :
2976
- result = post_processing (result , inferences )
2952
+ pp_kwargs = {}
2953
+ if needs_nullable :
2954
+ pp_kwargs ["nullable" ] = isinstance (values , BaseMaskedArray )
2977
2955
2978
- key = base .OutputKey (label = name , position = idx )
2956
+ result = post_processing (result , inferences , ** pp_kwargs )
2957
+
2958
+ return result
2959
+
2960
+ error_msg = ""
2961
+ for idx , obj in enumerate (self ._iterate_slices ()):
2962
+ values = obj ._values
2963
+
2964
+ if numeric_only and not is_numeric_dtype (values .dtype ):
2965
+ continue
2966
+
2967
+ try :
2968
+ result = blk_func (values )
2969
+ except TypeError as err :
2970
+ error_msg = str (err )
2971
+ howstr = how .replace ("group_" , "" )
2972
+ warnings .warn (
2973
+ "Dropping invalid columns in "
2974
+ f"{ type (self ).__name__ } .{ howstr } is deprecated. "
2975
+ "In a future version, a TypeError will be raised. "
2976
+ f"Before calling .{ howstr } , select only columns which "
2977
+ "should be valid for the function." ,
2978
+ FutureWarning ,
2979
+ stacklevel = 3 ,
2980
+ )
2981
+ continue
2982
+
2983
+ key = base .OutputKey (label = obj .name , position = idx )
2979
2984
output [key ] = result
2980
2985
2981
2986
# error_msg is "" on an frame/series with no rows or columns
0 commit comments