Skip to content

Commit 450ef27

Browse files
jorisvandenbosscheJulianWgs
authored andcommitted
PERF: reduce overhead in groupby _cython_operation (pandas-dev#40317)
1 parent 9089814 commit 450ef27

File tree

2 files changed

+27
-24
lines changed

2 files changed

+27
-24
lines changed

pandas/core/dtypes/missing.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -556,12 +556,12 @@ def infer_fill_value(val):
556556
return np.nan
557557

558558

559-
def maybe_fill(arr, fill_value=np.nan):
559+
def maybe_fill(arr: np.ndarray) -> np.ndarray:
560560
"""
561-
if we have a compatible fill_value and arr dtype, then fill
561+
Fill numpy.ndarray with NaN, unless we have a integer or boolean dtype.
562562
"""
563-
if isna_compat(arr, fill_value):
564-
arr.fill(fill_value)
563+
if arr.dtype.kind not in ("u", "i", "b"):
564+
arr.fill(np.nan)
565565
return arr
566566

567567

pandas/core/groupby/ops.py

+23-20
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import pandas._libs.groupby as libgroupby
3232
import pandas._libs.reduction as libreduction
3333
from pandas._typing import (
34-
ArrayLike,
34+
DtypeObj,
3535
F,
3636
FrameOrSeries,
3737
Shape,
@@ -46,7 +46,6 @@
4646
maybe_downcast_to_dtype,
4747
)
4848
from pandas.core.dtypes.common import (
49-
ensure_float,
5049
ensure_float64,
5150
ensure_int64,
5251
ensure_int_or_float,
@@ -491,7 +490,9 @@ def _get_cython_func_and_vals(
491490
return func, values
492491

493492
@final
494-
def _disallow_invalid_ops(self, values: ArrayLike, how: str):
493+
def _disallow_invalid_ops(
494+
self, dtype: DtypeObj, how: str, is_numeric: bool = False
495+
):
495496
"""
496497
Check if we can do this operation with our cython functions.
497498
@@ -501,7 +502,9 @@ def _disallow_invalid_ops(self, values: ArrayLike, how: str):
501502
This is either not a valid function for this dtype, or
502503
valid but not implemented in cython.
503504
"""
504-
dtype = values.dtype
505+
if is_numeric:
506+
# never an invalid op for those dtypes, so return early as fastpath
507+
return
505508

506509
if is_categorical_dtype(dtype) or is_sparse(dtype):
507510
# categoricals are only 1d, so we
@@ -589,32 +592,34 @@ def _cython_operation(
589592
# as we can have 1D ExtensionArrays that we need to treat as 2D
590593
assert axis == 1, axis
591594

595+
dtype = values.dtype
596+
is_numeric = is_numeric_dtype(dtype)
597+
592598
# can we do this operation with our cython functions
593599
# if not raise NotImplementedError
594-
self._disallow_invalid_ops(values, how)
600+
self._disallow_invalid_ops(dtype, how, is_numeric)
595601

596-
if is_extension_array_dtype(values.dtype):
602+
if is_extension_array_dtype(dtype):
597603
return self._ea_wrap_cython_operation(
598604
kind, values, how, axis, min_count, **kwargs
599605
)
600606

601-
is_datetimelike = needs_i8_conversion(values.dtype)
602-
is_numeric = is_numeric_dtype(values.dtype)
607+
is_datetimelike = needs_i8_conversion(dtype)
603608

604609
if is_datetimelike:
605610
values = values.view("int64")
606611
is_numeric = True
607-
elif is_bool_dtype(values.dtype):
612+
elif is_bool_dtype(dtype):
608613
values = ensure_int_or_float(values)
609-
elif is_integer_dtype(values):
614+
elif is_integer_dtype(dtype):
610615
# we use iNaT for the missing value on ints
611616
# so pre-convert to guard this condition
612617
if (values == iNaT).any():
613618
values = ensure_float64(values)
614619
else:
615620
values = ensure_int_or_float(values)
616-
elif is_numeric and not is_complex_dtype(values):
617-
values = ensure_float64(ensure_float(values))
621+
elif is_numeric and not is_complex_dtype(dtype):
622+
values = ensure_float64(values)
618623
else:
619624
values = values.astype(object)
620625

@@ -649,20 +654,18 @@ def _cython_operation(
649654
codes, _, _ = self.group_info
650655

651656
if kind == "aggregate":
652-
result = maybe_fill(np.empty(out_shape, dtype=out_dtype), fill_value=np.nan)
657+
result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
653658
counts = np.zeros(self.ngroups, dtype=np.int64)
654659
result = self._aggregate(result, counts, values, codes, func, min_count)
655660
elif kind == "transform":
656-
result = maybe_fill(
657-
np.empty(values.shape, dtype=out_dtype), fill_value=np.nan
658-
)
661+
result = maybe_fill(np.empty(values.shape, dtype=out_dtype))
659662

660663
# TODO: min_count
661664
result = self._transform(
662665
result, values, codes, func, is_datetimelike, **kwargs
663666
)
664667

665-
if is_integer_dtype(result) and not is_datetimelike:
668+
if is_integer_dtype(result.dtype) and not is_datetimelike:
666669
mask = result == iNaT
667670
if mask.any():
668671
result = result.astype("float64")
@@ -682,9 +685,9 @@ def _cython_operation(
682685
# e.g. if we are int64 and need to restore to datetime64/timedelta64
683686
# "rank" is the only member of cython_cast_blocklist we get here
684687
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
685-
# error: Argument 2 to "maybe_downcast_to_dtype" has incompatible type
686-
# "Union[dtype[Any], ExtensionDtype]"; expected "Union[str, dtype[Any]]"
687-
result = maybe_downcast_to_dtype(result, dtype) # type: ignore[arg-type]
688+
# error: Incompatible types in assignment (expression has type
689+
# "Union[ExtensionArray, ndarray]", variable has type "ndarray")
690+
result = maybe_downcast_to_dtype(result, dtype) # type: ignore[assignment]
688691

689692
return result
690693

0 commit comments

Comments
 (0)