Skip to content

Commit 520864a

Browse files
jbrockmendel authored and feefladder committed
REF: GroupBy._get_cythonized_result (pandas-dev#42742)
1 parent 53a4f11 commit 520864a

File tree

1 file changed

+44
-39
lines changed

1 file changed

+44
-39
lines changed

pandas/core/groupby/groupby.py

+44-39
Original file line number | Diff line number | Diff line change
@@ -2897,16 +2897,15 @@ def _get_cythonized_result(
28972897

28982898
ids, _, ngroups = grouper.group_info
28992899
output: dict[base.OutputKey, np.ndarray] = {}
2900-
base_func = getattr(libgroupby, how)
2901-
2902-
error_msg = ""
2903-
for idx, obj in enumerate(self._iterate_slices()):
2904-
name = obj.name
2905-
values = obj._values
29062900

2907-
if numeric_only and not is_numeric_dtype(values.dtype):
2908-
continue
2901+
base_func = getattr(libgroupby, how)
2902+
base_func = partial(base_func, labels=ids)
2903+
if needs_ngroups:
2904+
base_func = partial(base_func, ngroups=ngroups)
2905+
if min_count is not None:
2906+
base_func = partial(base_func, min_count=min_count)
29092907

2908+
def blk_func(values: ArrayLike) -> ArrayLike:
29102909
if aggregate:
29112910
result_sz = ngroups
29122911
else:
@@ -2915,54 +2914,31 @@ def _get_cythonized_result(
29152914
result = np.zeros(result_sz, dtype=cython_dtype)
29162915
if needs_2d:
29172916
result = result.reshape((-1, 1))
2918-
func = partial(base_func, result)
2917+
func = partial(base_func, out=result)
29192918

29202919
inferences = None
29212920

29222921
if needs_counts:
29232922
counts = np.zeros(self.ngroups, dtype=np.int64)
2924-
func = partial(func, counts)
2923+
func = partial(func, counts=counts)
29252924

29262925
if needs_values:
29272926
vals = values
29282927
if pre_processing:
2929-
try:
2930-
vals, inferences = pre_processing(vals)
2931-
except TypeError as err:
2932-
error_msg = str(err)
2933-
howstr = how.replace("group_", "")
2934-
warnings.warn(
2935-
"Dropping invalid columns in "
2936-
f"{type(self).__name__}.{howstr} is deprecated. "
2937-
"In a future version, a TypeError will be raised. "
2938-
f"Before calling .{howstr}, select only columns which "
2939-
"should be valid for the function.",
2940-
FutureWarning,
2941-
stacklevel=3,
2942-
)
2943-
continue
2928+
vals, inferences = pre_processing(vals)
2929+
29442930
vals = vals.astype(cython_dtype, copy=False)
29452931
if needs_2d:
29462932
vals = vals.reshape((-1, 1))
2947-
func = partial(func, vals)
2948-
2949-
func = partial(func, ids)
2950-
2951-
if min_count is not None:
2952-
func = partial(func, min_count)
2933+
func = partial(func, values=vals)
29532934

29542935
if needs_mask:
29552936
mask = isna(values).view(np.uint8)
2956-
func = partial(func, mask)
2957-
2958-
if needs_ngroups:
2959-
func = partial(func, ngroups)
2937+
func = partial(func, mask=mask)
29602938

29612939
if needs_nullable:
29622940
is_nullable = isinstance(values, BaseMaskedArray)
29632941
func = partial(func, nullable=is_nullable)
2964-
if post_processing:
2965-
post_processing = partial(post_processing, nullable=is_nullable)
29662942

29672943
func(**kwargs) # Call func to modify indexer values in place
29682944

@@ -2973,9 +2949,38 @@ def _get_cythonized_result(
29732949
result = algorithms.take_nd(values, result)
29742950

29752951
if post_processing:
2976-
result = post_processing(result, inferences)
2952+
pp_kwargs = {}
2953+
if needs_nullable:
2954+
pp_kwargs["nullable"] = isinstance(values, BaseMaskedArray)
29772955

2978-
key = base.OutputKey(label=name, position=idx)
2956+
result = post_processing(result, inferences, **pp_kwargs)
2957+
2958+
return result
2959+
2960+
error_msg = ""
2961+
for idx, obj in enumerate(self._iterate_slices()):
2962+
values = obj._values
2963+
2964+
if numeric_only and not is_numeric_dtype(values.dtype):
2965+
continue
2966+
2967+
try:
2968+
result = blk_func(values)
2969+
except TypeError as err:
2970+
error_msg = str(err)
2971+
howstr = how.replace("group_", "")
2972+
warnings.warn(
2973+
"Dropping invalid columns in "
2974+
f"{type(self).__name__}.{howstr} is deprecated. "
2975+
"In a future version, a TypeError will be raised. "
2976+
f"Before calling .{howstr}, select only columns which "
2977+
"should be valid for the function.",
2978+
FutureWarning,
2979+
stacklevel=3,
2980+
)
2981+
continue
2982+
2983+
key = base.OutputKey(label=obj.name, position=idx)
29792984
output[key] = result
29802985

29812986
# error_msg is "" on a frame/series with no rows or columns

0 commit comments

Comments
 (0)