Skip to content

Commit 6d35836

Browse files
jbrockmendeljreback
authored andcommitted
Stop catching TypeError in groupby methods (#29060)
1 parent 509eb14 commit 6d35836

File tree

3 files changed

+41
-21
lines changed

3 files changed

+41
-21
lines changed

pandas/_libs/groupby.pyx

+24-9
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import numpy as np
88
cimport numpy as cnp
99
from numpy cimport (ndarray,
1010
int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t,
11-
uint32_t, uint64_t, float32_t, float64_t)
11+
uint32_t, uint64_t, float32_t, float64_t, complex64_t, complex128_t)
1212
cnp.import_array()
1313

1414

@@ -421,30 +421,38 @@ def group_any_all(uint8_t[:] out,
421421
if values[i] == flag_val:
422422
out[lab] = flag_val
423423

424+
424425
# ----------------------------------------------------------------------
425426
# group_add, group_prod, group_var, group_mean, group_ohlc
426427
# ----------------------------------------------------------------------
427428

429+
ctypedef fused complexfloating_t:
430+
float64_t
431+
float32_t
432+
complex64_t
433+
complex128_t
434+
428435

429436
@cython.wraparound(False)
430437
@cython.boundscheck(False)
431-
def _group_add(floating[:, :] out,
438+
def _group_add(complexfloating_t[:, :] out,
432439
int64_t[:] counts,
433-
floating[:, :] values,
440+
complexfloating_t[:, :] values,
434441
const int64_t[:] labels,
435442
Py_ssize_t min_count=0):
436443
"""
437444
Only aggregates on axis=0
438445
"""
439446
cdef:
440447
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
441-
floating val, count
442-
floating[:, :] sumx, nobs
448+
complexfloating_t val, count
449+
complexfloating_t[:, :] sumx
450+
int64_t[:, :] nobs
443451

444452
if len(values) != len(labels):
445453
raise ValueError("len(index) != len(labels)")
446454

447-
nobs = np.zeros_like(out)
455+
nobs = np.zeros((len(out), out.shape[1]), dtype=np.int64)
448456
sumx = np.zeros_like(out)
449457

450458
N, K = (<object>values).shape
@@ -462,7 +470,12 @@ def _group_add(floating[:, :] out,
462470
# not nan
463471
if val == val:
464472
nobs[lab, j] += 1
465-
sumx[lab, j] += val
473+
if (complexfloating_t is complex64_t or
474+
complexfloating_t is complex128_t):
475+
# clang errors if we use += with these dtypes
476+
sumx[lab, j] = sumx[lab, j] + val
477+
else:
478+
sumx[lab, j] += val
466479

467480
for i in range(ncounts):
468481
for j in range(K):
@@ -472,8 +485,10 @@ def _group_add(floating[:, :] out,
472485
out[i, j] = sumx[i, j]
473486

474487

475-
group_add_float32 = _group_add['float']
476-
group_add_float64 = _group_add['double']
488+
group_add_float32 = _group_add['float32_t']
489+
group_add_float64 = _group_add['float64_t']
490+
group_add_complex64 = _group_add['float complex']
491+
group_add_complex128 = _group_add['double complex']
477492

478493

479494
@cython.wraparound(False)

pandas/core/groupby/groupby.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -1340,19 +1340,18 @@ def f(self, **kwargs):
13401340
# try a cython aggregation if we can
13411341
try:
13421342
return self._cython_agg_general(alias, alt=npfunc, **kwargs)
1343-
except AssertionError:
1344-
raise
13451343
except DataError:
13461344
pass
1347-
except (TypeError, NotImplementedError):
1348-
# TODO:
1349-
# - TypeError: this is reached via test_groupby_complex
1350-
# and can be fixed by implementing _group_add for
1351-
# complex dtypes
1352-
# - NotImplementedError: reached in test_max_nan_bug,
1353-
# raised in _get_cython_function and should probably
1354-
# be handled inside _cython_agg_blocks
1355-
pass
1345+
except NotImplementedError as err:
1346+
if "function is not implemented for this dtype" in str(err):
1347+
# raised in _get_cython_function, in some cases can
1348+
# be trimmed by implementing cython funcs for more dtypes
1349+
pass
1350+
elif "decimal does not support skipna=True" in str(err):
1351+
# FIXME: kludge for test_decimal:test_in_numeric_groupby
1352+
pass
1353+
else:
1354+
raise
13561355

13571356
# apply a non-cython aggregation
13581357
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))

pandas/core/groupby/ops.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,13 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1, **kwargs):
526526
func = self._get_cython_function(kind, how, values, is_numeric)
527527
except NotImplementedError:
528528
if is_numeric:
529-
values = ensure_float64(values)
529+
try:
530+
values = ensure_float64(values)
531+
except TypeError:
532+
if lib.infer_dtype(values, skipna=False) == "complex":
533+
values = values.astype(complex)
534+
else:
535+
raise
530536
func = self._get_cython_function(kind, how, values, is_numeric)
531537
else:
532538
raise

0 commit comments

Comments
 (0)