Skip to content

Commit 4fe34ef

Browse files
authored
REF: simplify dispatch in groupby.ops (pandas-dev#40681)
1 parent 867dc94 commit 4fe34ef

File tree

4 files changed

+59
-89
lines changed

4 files changed

+59
-89
lines changed

pandas/_libs/groupby.pyx

+21-39
Original file line numberDiff line numberDiff line change
@@ -455,11 +455,11 @@ ctypedef fused complexfloating_t:
455455

456456
@cython.wraparound(False)
457457
@cython.boundscheck(False)
458-
def _group_add(complexfloating_t[:, ::1] out,
459-
int64_t[::1] counts,
460-
ndarray[complexfloating_t, ndim=2] values,
461-
const intp_t[:] labels,
462-
Py_ssize_t min_count=0):
458+
def group_add(complexfloating_t[:, ::1] out,
459+
int64_t[::1] counts,
460+
ndarray[complexfloating_t, ndim=2] values,
461+
const intp_t[:] labels,
462+
Py_ssize_t min_count=0):
463463
"""
464464
Only aggregates on axis=0 using Kahan summation
465465
"""
@@ -506,19 +506,13 @@ def _group_add(complexfloating_t[:, ::1] out,
506506
out[i, j] = sumx[i, j]
507507

508508

509-
group_add_float32 = _group_add['float32_t']
510-
group_add_float64 = _group_add['float64_t']
511-
group_add_complex64 = _group_add['float complex']
512-
group_add_complex128 = _group_add['double complex']
513-
514-
515509
@cython.wraparound(False)
516510
@cython.boundscheck(False)
517-
def _group_prod(floating[:, ::1] out,
518-
int64_t[::1] counts,
519-
ndarray[floating, ndim=2] values,
520-
const intp_t[:] labels,
521-
Py_ssize_t min_count=0):
511+
def group_prod(floating[:, ::1] out,
512+
int64_t[::1] counts,
513+
ndarray[floating, ndim=2] values,
514+
const intp_t[:] labels,
515+
Py_ssize_t min_count=0):
522516
"""
523517
Only aggregates on axis=0
524518
"""
@@ -560,19 +554,15 @@ def _group_prod(floating[:, ::1] out,
560554
out[i, j] = prodx[i, j]
561555

562556

563-
group_prod_float32 = _group_prod['float']
564-
group_prod_float64 = _group_prod['double']
565-
566-
567557
@cython.wraparound(False)
568558
@cython.boundscheck(False)
569559
@cython.cdivision(True)
570-
def _group_var(floating[:, ::1] out,
571-
int64_t[::1] counts,
572-
ndarray[floating, ndim=2] values,
573-
const intp_t[:] labels,
574-
Py_ssize_t min_count=-1,
575-
int64_t ddof=1):
560+
def group_var(floating[:, ::1] out,
561+
int64_t[::1] counts,
562+
ndarray[floating, ndim=2] values,
563+
const intp_t[:] labels,
564+
Py_ssize_t min_count=-1,
565+
int64_t ddof=1):
576566
cdef:
577567
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
578568
floating val, ct, oldmean
@@ -619,17 +609,13 @@ def _group_var(floating[:, ::1] out,
619609
out[i, j] /= (ct - ddof)
620610

621611

622-
group_var_float32 = _group_var['float']
623-
group_var_float64 = _group_var['double']
624-
625-
626612
@cython.wraparound(False)
627613
@cython.boundscheck(False)
628-
def _group_mean(floating[:, ::1] out,
629-
int64_t[::1] counts,
630-
ndarray[floating, ndim=2] values,
631-
const intp_t[::1] labels,
632-
Py_ssize_t min_count=-1):
614+
def group_mean(floating[:, ::1] out,
615+
int64_t[::1] counts,
616+
ndarray[floating, ndim=2] values,
617+
const intp_t[::1] labels,
618+
Py_ssize_t min_count=-1):
633619
cdef:
634620
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
635621
floating val, count, y, t
@@ -675,10 +661,6 @@ def _group_mean(floating[:, ::1] out,
675661
out[i, j] = sumx[i, j] / count
676662

677663

678-
group_mean_float32 = _group_mean['float']
679-
group_mean_float64 = _group_mean['double']
680-
681-
682664
@cython.wraparound(False)
683665
@cython.boundscheck(False)
684666
def group_ohlc(floating[:, ::1] out,

pandas/core/groupby/groupby.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1602,7 +1602,7 @@ def std(self, ddof: int = 1):
16021602
Standard deviation of values within each group.
16031603
"""
16041604
return self._get_cythonized_result(
1605-
"group_var_float64",
1605+
"group_var",
16061606
aggregate=True,
16071607
needs_counts=True,
16081608
needs_values=True,

pandas/core/groupby/ops.py

+34-45
Original file line numberDiff line numberDiff line change
@@ -129,31 +129,22 @@ def _get_cython_function(kind: str, how: str, dtype: np.dtype, is_numeric: bool)
129129
# see if there is a fused-type version of function
130130
# only valid for numeric
131131
f = getattr(libgroupby, ftype, None)
132-
if f is not None and is_numeric:
133-
return f
134-
135-
# otherwise find dtype-specific version, falling back to object
136-
for dt in [dtype_str, "object"]:
137-
f2 = getattr(libgroupby, f"{ftype}_{dt}", None)
138-
if f2 is not None:
139-
return f2
140-
141-
if hasattr(f, "__signatures__"):
142-
# inspect what fused types are implemented
143-
if dtype_str == "object" and "object" not in f.__signatures__:
144-
# disallow this function so we get a NotImplementedError below
145-
# instead of a TypeError at runtime
146-
f = None
147-
148-
func = f
149-
150-
if func is None:
151-
raise NotImplementedError(
152-
f"function is not implemented for this dtype: "
153-
f"[how->{how},dtype->{dtype_str}]"
154-
)
132+
if f is not None:
133+
if is_numeric:
134+
return f
135+
elif dtype == object:
136+
if "object" not in f.__signatures__:
137+
# raise NotImplementedError here rather than TypeError later
138+
raise NotImplementedError(
139+
f"function is not implemented for this dtype: "
140+
f"[how->{how},dtype->{dtype_str}]"
141+
)
142+
return f
155143

156-
return func
144+
raise NotImplementedError(
145+
f"function is not implemented for this dtype: "
146+
f"[how->{how},dtype->{dtype_str}]"
147+
)
157148

158149

159150
class BaseGrouper:
@@ -475,25 +466,24 @@ def _get_cython_func_and_vals(
475466
func : callable
476467
values : np.ndarray
477468
"""
478-
try:
479-
func = _get_cython_function(kind, how, values.dtype, is_numeric)
480-
except NotImplementedError:
469+
if how in ["median", "cumprod"]:
470+
# these two only have float64 implementations
481471
if is_numeric:
482-
try:
483-
values = ensure_float64(values)
484-
except TypeError:
485-
if lib.infer_dtype(values, skipna=False) == "complex":
486-
values = values.astype(complex)
487-
else:
488-
raise
489-
func = _get_cython_function(kind, how, values.dtype, is_numeric)
472+
values = ensure_float64(values)
490473
else:
491-
raise
492-
else:
493-
if values.dtype.kind in ["i", "u"]:
494-
if how in ["ohlc"]:
495-
# The output may still include nans, so we have to cast
496-
values = ensure_float64(values)
474+
raise NotImplementedError(
475+
f"function is not implemented for this dtype: "
476+
f"[how->{how},dtype->{values.dtype.name}]"
477+
)
478+
func = getattr(libgroupby, f"group_{how}_float64")
479+
return func, values
480+
481+
func = _get_cython_function(kind, how, values.dtype, is_numeric)
482+
483+
if values.dtype.kind in ["i", "u"]:
484+
if how in ["add", "var", "prod", "mean", "ohlc"]:
485+
# result may still include NaN, so we have to cast
486+
values = ensure_float64(values)
497487

498488
return func, values
499489

@@ -643,10 +633,9 @@ def _cython_operation(
643633
values = ensure_float64(values)
644634
else:
645635
values = ensure_int_or_float(values)
646-
elif is_numeric and not is_complex_dtype(dtype):
647-
values = ensure_float64(values)
648-
else:
649-
values = values.astype(object)
636+
elif is_numeric:
637+
if not is_complex_dtype(dtype):
638+
values = ensure_float64(values)
650639

651640
arity = self._cython_arity.get(how, 1)
652641
ngroups = self.ngroups

pandas/tests/groupby/test_libgroupby.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from pandas._libs.groupby import (
55
group_cumprod_float64,
66
group_cumsum,
7-
group_var_float32,
8-
group_var_float64,
7+
group_var,
98
)
109

1110
from pandas.core.dtypes.common import ensure_platform_int
@@ -102,7 +101,7 @@ def test_group_var_constant(self):
102101
class TestGroupVarFloat64(GroupVarTestMixin):
103102
__test__ = True
104103

105-
algo = staticmethod(group_var_float64)
104+
algo = staticmethod(group_var)
106105
dtype = np.float64
107106
rtol = 1e-5
108107

@@ -124,7 +123,7 @@ def test_group_var_large_inputs(self):
124123
class TestGroupVarFloat32(GroupVarTestMixin):
125124
__test__ = True
126125

127-
algo = staticmethod(group_var_float32)
126+
algo = staticmethod(group_var)
128127
dtype = np.float32
129128
rtol = 1e-2
130129

0 commit comments

Comments
 (0)