From 846c16ab1b1e666cc8c2118e3dd4b4114b053a4e Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 3 Dec 2021 18:00:43 +0100 Subject: [PATCH] PERF: reducing dtype checking overhead in groupby --- pandas/core/dtypes/common.py | 4 +++ pandas/core/groupby/ops.py | 55 +++++++++++++++++++++--------------- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index da69e70b89072..21cf3c53aacd7 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -1223,6 +1223,10 @@ def is_numeric_dtype(arr_or_dtype) -> bool: >>> is_numeric_dtype(np.array([], dtype=np.timedelta64)) False """ + try: + return arr_or_dtype.kind in "uifcb" + except AttributeError: + pass return _is_dtype_type( arr_or_dtype, classes_and_not_datetimelike(np.number, np.bool_) ) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 7915e107afae6..d3c7e8760c7a2 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -23,6 +23,7 @@ from pandas._libs import ( NaT, + Period, lib, ) import pandas._libs.groupby as libgroupby @@ -46,12 +47,10 @@ ensure_int64, ensure_platform_int, is_1d_only_ea_obj, - is_bool_dtype, is_categorical_dtype, is_complex_dtype, is_datetime64_any_dtype, is_float_dtype, - is_integer_dtype, is_numeric_dtype, is_sparse, is_timedelta64_dtype, @@ -262,27 +261,35 @@ def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape: out_shape = (ngroups,) + values.shape[1:] return out_shape - def get_out_dtype(self, dtype: np.dtype) -> np.dtype: - how = self.how - + # Note: we make this a classmethod and pass how explicitly so that caching + # works at the class level and not the instance level + @classmethod + @functools.lru_cache(maxsize=None) + def get_out_dtype(cls, how: str, dtype: np.dtype) -> np.dtype: if how == "rank": out_dtype = "float64" else: - if is_numeric_dtype(dtype): + if dtype.kind in "uifcb": out_dtype = f"{dtype.kind}{dtype.itemsize}" 
else: out_dtype = "object" return np.dtype(out_dtype) @overload - def _get_result_dtype(self, dtype: np.dtype) -> np.dtype: + @classmethod + def _get_result_dtype(cls, dtype: np.dtype) -> np.dtype: ... # pragma: no cover @overload - def _get_result_dtype(self, dtype: ExtensionDtype) -> ExtensionDtype: + @classmethod + def _get_result_dtype(cls, dtype: ExtensionDtype) -> ExtensionDtype: ... # pragma: no cover - def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj: + # Note: we make this a classmethod and pass how explicitly so that caching + # works at the class level and not the instance level + @classmethod + @functools.lru_cache(maxsize=None) + def _get_result_dtype(cls, how: str, dtype: DtypeObj) -> DtypeObj: """ Get the desired dtype of a result based on the input dtype and how it was computed. @@ -297,8 +304,6 @@ def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj: np.dtype or ExtensionDtype The desired dtype of the result. """ - how = self.how - if how in ["add", "cumsum", "sum", "prod"]: if dtype == np.dtype(bool): return np.dtype(np.int64) @@ -382,7 +387,7 @@ def _reconstruct_ea_result(self, values, res_values): if isinstance( values.dtype, (BooleanDtype, _IntegerDtype, FloatingDtype, StringDtype) ): - dtype = self._get_result_dtype(values.dtype) + dtype = self._get_result_dtype(self.how, values.dtype) cls = dtype.construct_array_type() return cls._from_sequence(res_values, dtype=dtype) @@ -422,7 +427,7 @@ def _masked_ea_wrap_cython_operation( **kwargs, ) - dtype = self._get_result_dtype(orig_values.dtype) + dtype = self._get_result_dtype(self.how, orig_values.dtype) assert isinstance(dtype, BaseMaskedDtype) cls = dtype.construct_array_type() @@ -490,21 +495,26 @@ def _call_cython_op( orig_values = values dtype = values.dtype - is_numeric = is_numeric_dtype(dtype) + dtype_kind = dtype.kind + # is_numeric_dtype + is_numeric = dtype_kind in "uifcb" - is_datetimelike = needs_i8_conversion(dtype) + # is_datetimelike = needs_i8_conversion(dtype) + 
is_datetimelike = dtype_kind in ["m", "M"] or ( + dtype_kind == "O" and dtype.type is Period + ) if is_datetimelike: values = values.view("int64") is_numeric = True - elif is_bool_dtype(dtype): + elif dtype_kind == "b": values = values.astype("int64") - elif is_integer_dtype(dtype): + elif dtype_kind in "ui": # e.g. uint8 -> uint64, int16 -> int64 dtype_str = dtype.kind + "8" values = values.astype(dtype_str, copy=False) elif is_numeric: - if not is_complex_dtype(dtype): + if not dtype_kind == "c": values = ensure_float64(values) values = values.T @@ -515,7 +525,7 @@ def _call_cython_op( out_shape = self._get_output_shape(ngroups, values) func, values = self.get_cython_func_and_vals(values, is_numeric) - out_dtype = self.get_out_dtype(values.dtype) + out_dtype = self.get_out_dtype(self.how, values.dtype) result = maybe_fill(np.empty(out_shape, dtype=out_dtype)) if self.kind == "aggregate": @@ -562,7 +572,7 @@ def _call_cython_op( # i.e. counts is defined. Locations where count