@@ -8,7 +8,7 @@ import numpy as np
8
8
cimport numpy as cnp
9
9
from numpy cimport (ndarray,
10
10
int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t,
11
- uint32_t, uint64_t, float32_t, float64_t)
11
+ uint32_t, uint64_t, float32_t, float64_t, complex64_t, complex128_t )
12
12
cnp.import_array()
13
13
14
14
@@ -421,30 +421,38 @@ def group_any_all(uint8_t[:] out,
421
421
if values[i] == flag_val:
422
422
out[lab] = flag_val
423
423
424
+
424
425
# ----------------------------------------------------------------------
425
426
# group_add, group_prod, group_var, group_mean, group_ohlc
426
427
# ----------------------------------------------------------------------
427
428
429
+ ctypedef fused complexfloating_t:
430
+ float64_t
431
+ float32_t
432
+ complex64_t
433
+ complex128_t
434
+
428
435
429
436
@ cython.wraparound (False )
430
437
@ cython.boundscheck (False )
431
- def _group_add (floating [:, :] out ,
438
+ def _group_add (complexfloating_t [:, :] out ,
432
439
int64_t[:] counts ,
433
- floating [:, :] values ,
440
+ complexfloating_t [:, :] values ,
434
441
const int64_t[:] labels ,
435
442
Py_ssize_t min_count = 0 ):
436
443
"""
437
444
Only aggregates on axis=0
438
445
"""
439
446
cdef:
440
447
Py_ssize_t i, j, N, K, lab, ncounts = len (counts)
441
- floating val, count
442
- floating[:, :] sumx, nobs
448
+ complexfloating_t val, count
449
+ complexfloating_t[:, :] sumx
450
+ int64_t[:, :] nobs
443
451
444
452
if len (values) != len (labels):
445
453
raise ValueError (" len(index) != len(labels)" )
446
454
447
- nobs = np.zeros_like( out)
455
+ nobs = np.zeros(( len ( out), out.shape[ 1 ]), dtype = np.int64 )
448
456
sumx = np.zeros_like(out)
449
457
450
458
N, K = (< object > values).shape
@@ -462,7 +470,12 @@ def _group_add(floating[:, :] out,
462
470
# not nan
463
471
if val == val:
464
472
nobs[lab, j] += 1
465
- sumx[lab, j] += val
473
+ if (complexfloating_t is complex64_t or
474
+ complexfloating_t is complex128_t):
475
+ # clang errors if we use += with these dtypes
476
+ sumx[lab, j] = sumx[lab, j] + val
477
+ else :
478
+ sumx[lab, j] += val
466
479
467
480
for i in range (ncounts):
468
481
for j in range (K):
@@ -472,8 +485,10 @@ def _group_add(floating[:, :] out,
472
485
out[i, j] = sumx[i, j]
473
486
474
487
475
- group_add_float32 = _group_add[' float' ]
476
- group_add_float64 = _group_add[' double' ]
488
+ group_add_float32 = _group_add[' float32_t' ]
489
+ group_add_float64 = _group_add[' float64_t' ]
490
+ group_add_complex64 = _group_add[' float complex' ]
491
+ group_add_complex128 = _group_add[' double complex' ]
477
492
478
493
479
494
@ cython.wraparound (False )
0 commit comments