From 07e59eb84ac28b4eb02f4dcb23397b8e9175b4f4 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 1 Apr 2021 09:34:38 -0700 Subject: [PATCH 1/2] REF: implement groupby.ops.WrappedCythonOp --- pandas/core/groupby/ops.py | 334 ++++++++++++++++++------------------- 1 file changed, 164 insertions(+), 170 deletions(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 4c086f3b8612e..0ca63bfd7d871 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -97,54 +97,164 @@ get_indexer_dict, ) -_CYTHON_FUNCTIONS = { - "aggregate": { - "add": "group_add", - "prod": "group_prod", - "min": "group_min", - "max": "group_max", - "mean": "group_mean", - "median": "group_median", - "var": "group_var", - "first": "group_nth", - "last": "group_last", - "ohlc": "group_ohlc", - }, - "transform": { - "cumprod": "group_cumprod", - "cumsum": "group_cumsum", - "cummin": "group_cummin", - "cummax": "group_cummax", - "rank": "group_rank", - }, -} - - -@functools.lru_cache(maxsize=None) -def _get_cython_function(kind: str, how: str, dtype: np.dtype, is_numeric: bool): - - dtype_str = dtype.name - ftype = _CYTHON_FUNCTIONS[kind][how] - - # see if there is a fused-type version of function - # only valid for numeric - f = getattr(libgroupby, ftype, None) - if f is not None: - if is_numeric: - return f - elif dtype == object: - if "object" not in f.__signatures__: - # raise NotImplementedError here rather than TypeError later + +class WrappedCythonOp: + """ + Dispatch logic for functions defined in _libs.groupby + """ + + _CYTHON_FUNCTIONS = { + "aggregate": { + "add": "group_add", + "prod": "group_prod", + "min": "group_min", + "max": "group_max", + "mean": "group_mean", + "median": "group_median", + "var": "group_var", + "first": "group_nth", + "last": "group_last", + "ohlc": "group_ohlc", + }, + "transform": { + "cumprod": "group_cumprod", + "cumsum": "group_cumsum", + "cummin": "group_cummin", + "cummax": "group_cummax", + "rank": "group_rank", + }, + 
} + + _cython_arity = {"ohlc": 4} # OHLC + + @staticmethod + @functools.lru_cache(maxsize=None) + def get_cython_function(kind: str, how: str, dtype: np.dtype, is_numeric: bool): + + dtype_str = dtype.name + ftype = WrappedCythonOp._CYTHON_FUNCTIONS[kind][how] + + # see if there is a fused-type version of function + # only valid for numeric + f = getattr(libgroupby, ftype, None) + if f is not None: + if is_numeric: + return f + elif dtype == object: + if "object" not in f.__signatures__: + # raise NotImplementedError here rather than TypeError later + raise NotImplementedError( + f"function is not implemented for this dtype: " + f"[how->{how},dtype->{dtype_str}]" + ) + return f + + raise NotImplementedError( + f"function is not implemented for this dtype: " + f"[how->{how},dtype->{dtype_str}]" + ) + + @classmethod + def get_cython_func_and_vals( + cls, kind: str, how: str, values: np.ndarray, is_numeric: bool + ): + """ + Find the appropriate cython function, casting if necessary. + + Parameters + ---------- + kind : str + how : str + values : np.ndarray + is_numeric : bool + + Returns + ------- + func : callable + values : np.ndarray + """ + if how in ["median", "cumprod"]: + # these two only have float64 implementations + if is_numeric: + values = ensure_float64(values) + else: raise NotImplementedError( f"function is not implemented for this dtype: " - f"[how->{how},dtype->{dtype_str}]" + f"[how->{how},dtype->{values.dtype.name}]" ) - return f + func = getattr(libgroupby, f"group_{how}_float64") + return func, values + + func = cls.get_cython_function(kind, how, values.dtype, is_numeric) + + if values.dtype.kind in ["i", "u"]: + if how in ["add", "var", "prod", "mean", "ohlc"]: + # result may still include NaN, so we have to cast + values = ensure_float64(values) - raise NotImplementedError( - f"function is not implemented for this dtype: " - f"[how->{how},dtype->{dtype_str}]" - ) + return func, values + + @staticmethod + def disallow_invalid_ops(dtype: 
DtypeObj, how: str, is_numeric: bool = False): + """ + Check if we can do this operation with our cython functions. + + Raises + ------ + NotImplementedError + This is either not a valid function for this dtype, or + valid but not implemented in cython. + """ + if is_numeric: + # never an invalid op for those dtypes, so return early as fastpath + return + + if is_categorical_dtype(dtype) or is_sparse(dtype): + # categoricals are only 1d, so we + # are not setup for dim transforming + raise NotImplementedError(f"{dtype} dtype not supported") + elif is_datetime64_any_dtype(dtype): + # we raise NotImplemented if this is an invalid operation + # entirely, e.g. adding datetimes + if how in ["add", "prod", "cumsum", "cumprod"]: + raise NotImplementedError( + f"datetime64 type does not support {how} operations" + ) + elif is_timedelta64_dtype(dtype): + if how in ["prod", "cumprod"]: + raise NotImplementedError( + f"timedelta64 type does not support {how} operations" + ) + + @staticmethod + def get_output_shape( + how: str, kind: str, ngroups: int, values: np.ndarray + ) -> Shape: + + arity = WrappedCythonOp._cython_arity.get(how, 1) + + if how == "ohlc": + out_shape = (ngroups, 4) + elif arity > 1: + raise NotImplementedError( + "arity of more than 1 is not supported for the 'how' argument" + ) + elif kind == "transform": + out_shape = values.shape + else: + out_shape = (ngroups,) + values.shape[1:] + return out_shape + + @staticmethod + def get_out_dtype(how: str, dtype: np.dtype) -> np.dtype: + if how == "rank": + out_dtype = "float64" + else: + if is_numeric_dtype(dtype): + out_dtype = f"{dtype.kind}{dtype.itemsize}" + else: + out_dtype = "object" + return np.dtype(out_dtype) class BaseGrouper: @@ -437,8 +547,6 @@ def get_group_levels(self) -> List[Index]: # ------------------------------------------------------------ # Aggregation functions - _cython_arity = {"ohlc": 4} # OHLC - @final def _is_builtin_func(self, arg): """ @@ -447,80 +555,6 @@ def 
_is_builtin_func(self, arg): """ return SelectionMixin._builtin_table.get(arg, arg) - @final - def _get_cython_func_and_vals( - self, kind: str, how: str, values: np.ndarray, is_numeric: bool - ): - """ - Find the appropriate cython function, casting if necessary. - - Parameters - ---------- - kind : str - how : str - values : np.ndarray - is_numeric : bool - - Returns - ------- - func : callable - values : np.ndarray - """ - if how in ["median", "cumprod"]: - # these two only have float64 implementations - if is_numeric: - values = ensure_float64(values) - else: - raise NotImplementedError( - f"function is not implemented for this dtype: " - f"[how->{how},dtype->{values.dtype.name}]" - ) - func = getattr(libgroupby, f"group_{how}_float64") - return func, values - - func = _get_cython_function(kind, how, values.dtype, is_numeric) - - if values.dtype.kind in ["i", "u"]: - if how in ["add", "var", "prod", "mean", "ohlc"]: - # result may still include NaN, so we have to cast - values = ensure_float64(values) - - return func, values - - @final - def _disallow_invalid_ops( - self, dtype: DtypeObj, how: str, is_numeric: bool = False - ): - """ - Check if we can do this operation with our cython functions. - - Raises - ------ - NotImplementedError - This is either not a valid function for this dtype, or - valid but not implemented in cython. - """ - if is_numeric: - # never an invalid op for those dtypes, so return early as fastpath - return - - if is_categorical_dtype(dtype) or is_sparse(dtype): - # categoricals are only 1d, so we - # are not setup for dim transforming - raise NotImplementedError(f"{dtype} dtype not supported") - elif is_datetime64_any_dtype(dtype): - # we raise NotImplemented if this is an invalid operation - # entirely, e.g. 
adding datetimes - if how in ["add", "prod", "cumsum", "cumprod"]: - raise NotImplementedError( - f"datetime64 type does not support {how} operations" - ) - elif is_timedelta64_dtype(dtype): - if how in ["prod", "cumprod"]: - raise NotImplementedError( - f"timedelta64 type does not support {how} operations" - ) - @final def _ea_wrap_cython_operation( self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs @@ -595,7 +629,7 @@ def _cython_operation( # can we do this operation with our cython functions # if not raise NotImplementedError - self._disallow_invalid_ops(dtype, how, is_numeric) + WrappedCythonOp.disallow_invalid_ops(dtype, how, is_numeric) if is_extension_array_dtype(dtype): return self._ea_wrap_cython_operation( @@ -637,43 +671,25 @@ def _cython_operation( if not is_complex_dtype(dtype): values = ensure_float64(values) - arity = self._cython_arity.get(how, 1) ngroups = self.ngroups + comp_ids, _, _ = self.group_info assert axis == 1 values = values.T - if how == "ohlc": - out_shape = (ngroups, 4) - elif arity > 1: - raise NotImplementedError( - "arity of more than 1 is not supported for the 'how' argument" - ) - elif kind == "transform": - out_shape = values.shape - else: - out_shape = (ngroups,) + values.shape[1:] - - func, values = self._get_cython_func_and_vals(kind, how, values, is_numeric) - - if how == "rank": - out_dtype = "float" - else: - if is_numeric: - out_dtype = f"{values.dtype.kind}{values.dtype.itemsize}" - else: - out_dtype = "object" - codes, _, _ = self.group_info + out_shape = WrappedCythonOp.get_output_shape(how, kind, ngroups, values) + func, values = WrappedCythonOp.get_cython_func_and_vals( + kind, how, values, is_numeric + ) + out_dtype = WrappedCythonOp.get_out_dtype(how, values.dtype) result = maybe_fill(np.empty(out_shape, dtype=out_dtype)) if kind == "aggregate": - counts = np.zeros(self.ngroups, dtype=np.int64) - result = self._aggregate(result, counts, values, codes, func, min_count) + counts = 
np.zeros(ngroups, dtype=np.int64) + func(result, counts, values, comp_ids, min_count) elif kind == "transform": # TODO: min_count - result = self._transform( - result, values, codes, func, is_datetimelike, **kwargs - ) + func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs) if is_integer_dtype(result.dtype) and not is_datetimelike: mask = result == iNaT @@ -697,28 +713,6 @@ def _cython_operation( return op_result - @final - def _aggregate( - self, result, counts, values, comp_ids, agg_func, min_count: int = -1 - ): - if agg_func is libgroupby.group_nth: - # different signature from the others - agg_func(result, counts, values, comp_ids, min_count, rank=1) - else: - agg_func(result, counts, values, comp_ids, min_count) - - return result - - @final - def _transform( - self, result, values, comp_ids, transform_func, is_datetimelike: bool, **kwargs - ): - - _, _, ngroups = self.group_info - transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs) - - return result - def agg_series(self, obj: Series, func: F): # Caller is responsible for checking ngroups != 0 assert self.ngroups != 0 From b41b8ddc7f1c0cd89d170d72163bc6c2a7ed6c57 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 1 Apr 2021 14:38:57 -0700 Subject: [PATCH 2/2] instantiate WrappedCythonOp, no staticmethods --- pandas/core/groupby/ops.py | 58 ++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 0ca63bfd7d871..bcf2b6be15953 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -103,6 +103,10 @@ class WrappedCythonOp: Dispatch logic for functions defined in _libs.groupby """ + def __init__(self, kind: str, how: str): + self.kind = kind + self.how = how + _CYTHON_FUNCTIONS = { "aggregate": { "add": "group_add", @@ -127,12 +131,16 @@ class WrappedCythonOp: _cython_arity = {"ohlc": 4} # OHLC - @staticmethod + # Note: we make this a classmethod and pass 
kind+how so that caching + # works at the class level and not the instance level + @classmethod @functools.lru_cache(maxsize=None) - def get_cython_function(kind: str, how: str, dtype: np.dtype, is_numeric: bool): + def _get_cython_function( + cls, kind: str, how: str, dtype: np.dtype, is_numeric: bool + ): dtype_str = dtype.name - ftype = WrappedCythonOp._CYTHON_FUNCTIONS[kind][how] + ftype = cls._CYTHON_FUNCTIONS[kind][how] # see if there is a fused-type version of function # only valid for numeric @@ -154,17 +162,12 @@ def get_cython_function(kind: str, how: str, dtype: np.dtype, is_numeric: bool): f"[how->{how},dtype->{dtype_str}]" ) - @classmethod - def get_cython_func_and_vals( - cls, kind: str, how: str, values: np.ndarray, is_numeric: bool - ): + def get_cython_func_and_vals(self, values: np.ndarray, is_numeric: bool): """ Find the appropriate cython function, casting if necessary. Parameters ---------- - kind : str - how : str values : np.ndarray is_numeric : bool @@ -173,6 +176,9 @@ def get_cython_func_and_vals( func : callable values : np.ndarray """ + how = self.how + kind = self.kind + if how in ["median", "cumprod"]: # these two only have float64 implementations if is_numeric: @@ -185,7 +191,7 @@ def get_cython_func_and_vals( func = getattr(libgroupby, f"group_{how}_float64") return func, values - func = cls.get_cython_function(kind, how, values.dtype, is_numeric) + func = self._get_cython_function(kind, how, values.dtype, is_numeric) if values.dtype.kind in ["i", "u"]: if how in ["add", "var", "prod", "mean", "ohlc"]: @@ -194,8 +200,7 @@ def get_cython_func_and_vals( return func, values - @staticmethod - def disallow_invalid_ops(dtype: DtypeObj, how: str, is_numeric: bool = False): + def disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False): """ Check if we can do this operation with our cython functions. 
@@ -205,6 +210,8 @@ def disallow_invalid_ops(dtype: DtypeObj, how: str, is_numeric: bool = False): This is either not a valid function for this dtype, or valid but not implemented in cython. """ + how = self.how + if is_numeric: # never an invalid op for those dtypes, so return early as fastpath return @@ -226,13 +233,13 @@ def disallow_invalid_ops(dtype: DtypeObj, how: str, is_numeric: bool = False): f"timedelta64 type does not support {how} operations" ) - @staticmethod - def get_output_shape( - how: str, kind: str, ngroups: int, values: np.ndarray - ) -> Shape: + def get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape: + how = self.how + kind = self.kind - arity = WrappedCythonOp._cython_arity.get(how, 1) + arity = self._cython_arity.get(how, 1) + out_shape: Shape if how == "ohlc": out_shape = (ngroups, 4) elif arity > 1: @@ -245,8 +252,9 @@ def get_output_shape( out_shape = (ngroups,) + values.shape[1:] return out_shape - @staticmethod - def get_out_dtype(how: str, dtype: np.dtype) -> np.dtype: + def get_out_dtype(self, dtype: np.dtype) -> np.dtype: + how = self.how + if how == "rank": out_dtype = "float64" else: @@ -627,9 +635,11 @@ def _cython_operation( dtype = values.dtype is_numeric = is_numeric_dtype(dtype) + cy_op = WrappedCythonOp(kind=kind, how=how) + # can we do this operation with our cython functions # if not raise NotImplementedError - WrappedCythonOp.disallow_invalid_ops(dtype, how, is_numeric) + cy_op.disallow_invalid_ops(dtype, is_numeric) if is_extension_array_dtype(dtype): return self._ea_wrap_cython_operation( @@ -677,11 +687,9 @@ def _cython_operation( assert axis == 1 values = values.T - out_shape = WrappedCythonOp.get_output_shape(how, kind, ngroups, values) - func, values = WrappedCythonOp.get_cython_func_and_vals( - kind, how, values, is_numeric - ) - out_dtype = WrappedCythonOp.get_out_dtype(how, values.dtype) + out_shape = cy_op.get_output_shape(ngroups, values) + func, values = cy_op.get_cython_func_and_vals(values, 
is_numeric) + out_dtype = cy_op.get_out_dtype(values.dtype) result = maybe_fill(np.empty(out_shape, dtype=out_dtype)) if kind == "aggregate":