diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx
index e23fa9b82f12e..b43a0d2eced93 100644
--- a/pandas/_libs/groupby.pyx
+++ b/pandas/_libs/groupby.pyx
@@ -455,11 +455,11 @@ ctypedef fused complexfloating_t:
 
 @cython.wraparound(False)
 @cython.boundscheck(False)
-def _group_add(complexfloating_t[:, ::1] out,
-               int64_t[::1] counts,
-               ndarray[complexfloating_t, ndim=2] values,
-               const intp_t[:] labels,
-               Py_ssize_t min_count=0):
+def group_add(complexfloating_t[:, ::1] out,
+              int64_t[::1] counts,
+              ndarray[complexfloating_t, ndim=2] values,
+              const intp_t[:] labels,
+              Py_ssize_t min_count=0):
     """
     Only aggregates on axis=0 using Kahan summation
     """
@@ -506,19 +506,13 @@ def _group_add(complexfloating_t[:, ::1] out,
                 out[i, j] = sumx[i, j]
 
 
-group_add_float32 = _group_add['float32_t']
-group_add_float64 = _group_add['float64_t']
-group_add_complex64 = _group_add['float complex']
-group_add_complex128 = _group_add['double complex']
-
-
 @cython.wraparound(False)
 @cython.boundscheck(False)
-def _group_prod(floating[:, ::1] out,
-                int64_t[::1] counts,
-                ndarray[floating, ndim=2] values,
-                const intp_t[:] labels,
-                Py_ssize_t min_count=0):
+def group_prod(floating[:, ::1] out,
+               int64_t[::1] counts,
+               ndarray[floating, ndim=2] values,
+               const intp_t[:] labels,
+               Py_ssize_t min_count=0):
     """
     Only aggregates on axis=0
     """
@@ -560,19 +554,15 @@ def _group_prod(floating[:, ::1] out,
                 out[i, j] = prodx[i, j]
 
 
-group_prod_float32 = _group_prod['float']
-group_prod_float64 = _group_prod['double']
-
-
 @cython.wraparound(False)
 @cython.boundscheck(False)
 @cython.cdivision(True)
-def _group_var(floating[:, ::1] out,
-               int64_t[::1] counts,
-               ndarray[floating, ndim=2] values,
-               const intp_t[:] labels,
-               Py_ssize_t min_count=-1,
-               int64_t ddof=1):
+def group_var(floating[:, ::1] out,
+              int64_t[::1] counts,
+              ndarray[floating, ndim=2] values,
+              const intp_t[:] labels,
+              Py_ssize_t min_count=-1,
+              int64_t ddof=1):
     cdef:
         Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
         floating val, ct, oldmean
@@ -619,17 +609,13 @@ def _group_var(floating[:, ::1] out,
                 out[i, j] /= (ct - ddof)
 
 
-group_var_float32 = _group_var['float']
-group_var_float64 = _group_var['double']
-
-
 @cython.wraparound(False)
 @cython.boundscheck(False)
-def _group_mean(floating[:, ::1] out,
-                int64_t[::1] counts,
-                ndarray[floating, ndim=2] values,
-                const intp_t[::1] labels,
-                Py_ssize_t min_count=-1):
+def group_mean(floating[:, ::1] out,
+               int64_t[::1] counts,
+               ndarray[floating, ndim=2] values,
+               const intp_t[::1] labels,
+               Py_ssize_t min_count=-1):
     cdef:
         Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
         floating val, count, y, t
@@ -675,10 +661,6 @@ def _group_mean(floating[:, ::1] out,
                 out[i, j] = sumx[i, j] / count
 
 
-group_mean_float32 = _group_mean['float']
-group_mean_float64 = _group_mean['double']
-
-
 @cython.wraparound(False)
 @cython.boundscheck(False)
 def group_ohlc(floating[:, ::1] out,
diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index 0ecd798986c53..a6c3cb3ff5d0b 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -1602,7 +1602,7 @@ def std(self, ddof: int = 1):
             Standard deviation of values within each group.
         """
         return self._get_cythonized_result(
-            "group_var_float64",
+            "group_var",
             aggregate=True,
             needs_counts=True,
             needs_values=True,
diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py
index 5442f90a25580..20b8dd99b8cd3 100644
--- a/pandas/core/groupby/ops.py
+++ b/pandas/core/groupby/ops.py
@@ -129,31 +129,22 @@ def _get_cython_function(kind: str, how: str, dtype: np.dtype, is_numeric: bool)
     # see if there is a fused-type version of function
     # only valid for numeric
     f = getattr(libgroupby, ftype, None)
-    if f is not None and is_numeric:
-        return f
-
-    # otherwise find dtype-specific version, falling back to object
-    for dt in [dtype_str, "object"]:
-        f2 = getattr(libgroupby, f"{ftype}_{dt}", None)
-        if f2 is not None:
-            return f2
-
-    if hasattr(f, "__signatures__"):
-        # inspect what fused types are implemented
-        if dtype_str == "object" and "object" not in f.__signatures__:
-            # disallow this function so we get a NotImplementedError below
-            # instead of a TypeError at runtime
-            f = None
-
-    func = f
-
-    if func is None:
-        raise NotImplementedError(
-            f"function is not implemented for this dtype: "
-            f"[how->{how},dtype->{dtype_str}]"
-        )
+    if f is not None:
+        if is_numeric:
+            return f
+        elif dtype == object:
+            if "object" not in f.__signatures__:
+                # raise NotImplementedError here rather than TypeError later
+                raise NotImplementedError(
+                    f"function is not implemented for this dtype: "
+                    f"[how->{how},dtype->{dtype_str}]"
+                )
+            return f
 
-    return func
+    raise NotImplementedError(
+        f"function is not implemented for this dtype: "
+        f"[how->{how},dtype->{dtype_str}]"
+    )
 
 
 class BaseGrouper:
@@ -475,25 +466,24 @@ def _get_cython_func_and_vals(
         func : callable
         values : np.ndarray
         """
-        try:
-            func = _get_cython_function(kind, how, values.dtype, is_numeric)
-        except NotImplementedError:
+        if how in ["median", "cumprod"]:
+            # these two only have float64 implementations
             if is_numeric:
-                try:
-                    values = ensure_float64(values)
-                except TypeError:
-                    if lib.infer_dtype(values, skipna=False) == "complex":
-                        values = values.astype(complex)
-                    else:
-                        raise
-                func = _get_cython_function(kind, how, values.dtype, is_numeric)
+                values = ensure_float64(values)
             else:
-                raise
-        else:
-            if values.dtype.kind in ["i", "u"]:
-                if how in ["ohlc"]:
-                    # The output may still include nans, so we have to cast
-                    values = ensure_float64(values)
+                raise NotImplementedError(
+                    f"function is not implemented for this dtype: "
+                    f"[how->{how},dtype->{values.dtype.name}]"
+                )
+            func = getattr(libgroupby, f"group_{how}_float64")
+            return func, values
+
+        func = _get_cython_function(kind, how, values.dtype, is_numeric)
+
+        if values.dtype.kind in ["i", "u"]:
+            if how in ["add", "var", "prod", "mean", "ohlc"]:
+                # result may still include NaN, so we have to cast
+                values = ensure_float64(values)
 
         return func, values
 
@@ -643,10 +633,9 @@ def _cython_operation(
                 values = ensure_float64(values)
             else:
                 values = ensure_int_or_float(values)
-        elif is_numeric and not is_complex_dtype(dtype):
-            values = ensure_float64(values)
-        else:
-            values = values.astype(object)
+        elif is_numeric:
+            if not is_complex_dtype(dtype):
+                values = ensure_float64(values)
 
         arity = self._cython_arity.get(how, 1)
         ngroups = self.ngroups
diff --git a/pandas/tests/groupby/test_libgroupby.py b/pandas/tests/groupby/test_libgroupby.py
index d776c34f5b5ec..7a9cadb6c8232 100644
--- a/pandas/tests/groupby/test_libgroupby.py
+++ b/pandas/tests/groupby/test_libgroupby.py
@@ -4,8 +4,7 @@
 from pandas._libs.groupby import (
     group_cumprod_float64,
     group_cumsum,
-    group_var_float32,
-    group_var_float64,
+    group_var,
 )
 
 from pandas.core.dtypes.common import ensure_platform_int
@@ -102,7 +101,7 @@ def test_group_var_constant(self):
 class TestGroupVarFloat64(GroupVarTestMixin):
     __test__ = True
 
-    algo = staticmethod(group_var_float64)
+    algo = staticmethod(group_var)
     dtype = np.float64
     rtol = 1e-5
 
@@ -124,7 +123,7 @@ def test_group_var_large_inputs(self):
 class TestGroupVarFloat32(GroupVarTestMixin):
     __test__ = True
 
-    algo = staticmethod(group_var_float32)
+    algo = staticmethod(group_var)
     dtype = np.float32
     rtol = 1e-2
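Note (not part of the patch): a minimal sketch of how the renamed fused-type kernel is called after this change, mirroring the pattern used by GroupVarTestMixin in pandas/tests/groupby/test_libgroupby.py. It assumes a pandas build containing this change; the input values, shapes, and prints are illustrative, not taken from the PR.

```python
import numpy as np

from pandas._libs.groupby import group_var  # fused-type kernel, no dtype suffix

# One group (label 0) holding four float32 observations.
values = np.array([[1.0], [2.0], [4.0], [5.0]], dtype=np.float32)
labels = np.zeros(4, dtype=np.intp)

# Output buffers must be C-contiguous and match the value dtype;
# Cython dispatches the fused "floating" specialization on these arguments.
out = np.zeros((1, 1), dtype=np.float32)
counts = np.zeros(1, dtype=np.int64)

group_var(out, counts, values, labels)

print(out[0, 0])   # sample variance (ddof=1) of [1, 2, 4, 5]
print(counts[0])   # observations seen in group 0 -> 4
```

The same call works unchanged with float64 buffers, which is why TestGroupVarFloat32 and TestGroupVarFloat64 can now share the single group_var import instead of the removed group_var_float32/group_var_float64 aliases.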