@@ -289,10 +289,8 @@ def group_cumprod(
289
289
290
290
if uses_mask:
291
291
isna_entry = mask[i, j]
292
- elif int64float_t is float64_t or int64float_t is float32_t:
293
- isna_entry = val != val
294
292
else :
295
- isna_entry = False
293
+ isna_entry = _treat_as_na(val, False )
296
294
297
295
if not isna_entry:
298
296
isna_prev = accum_mask[lab, j]
@@ -737,23 +735,10 @@ def group_sum(
737
735
for j in range (K):
738
736
val = values[i, j]
739
737
740
- # not nan
741
- # With dt64/td64 values, values have been cast to float64
742
- # instead if int64 for group_sum, but the logic
743
- # is otherwise the same as in _treat_as_na
744
738
if uses_mask:
745
739
isna_entry = mask[i, j]
746
- elif (
747
- sum_t is float32_t
748
- or sum_t is float64_t
749
- or sum_t is complex64_t
750
- ):
751
- # avoid warnings because of equality comparison
752
- isna_entry = not val == val
753
- elif sum_t is int64_t and is_datetimelike and val == NPY_NAT:
754
- isna_entry = True
755
740
else :
756
- isna_entry = False
741
+ isna_entry = _treat_as_na(val, is_datetimelike)
757
742
758
743
if not isna_entry:
759
744
nobs[lab, j] += 1
@@ -831,10 +816,8 @@ def group_prod(
831
816
832
817
if uses_mask:
833
818
isna_entry = mask[i, j]
834
- elif int64float_t is float32_t or int64float_t is float64_t:
835
- isna_entry = not val == val
836
819
else :
837
- isna_entry = False
820
+ isna_entry = _treat_as_na(val, False )
838
821
839
822
if not isna_entry:
840
823
nobs[lab, j] += 1
@@ -906,7 +889,7 @@ def group_var(
906
889
if uses_mask:
907
890
isna_entry = mask[i, j]
908
891
else :
909
- isna_entry = not val == val
892
+ isna_entry = _treat_as_na( val, False )
910
893
911
894
if not isna_entry:
912
895
nobs[lab, j] += 1
@@ -1008,9 +991,12 @@ def group_mean(
1008
991
if uses_mask:
1009
992
isna_entry = mask[i, j]
1010
993
elif is_datetimelike:
994
+ # With group_mean, we cannot just use _treat_as_na bc
995
+ # datetimelike dtypes get cast to float64 instead of
996
+ # to int64.
1011
997
isna_entry = val == NPY_NAT
1012
998
else :
1013
- isna_entry = not val == val
999
+ isna_entry = _treat_as_na( val, is_datetimelike)
1014
1000
1015
1001
if not isna_entry:
1016
1002
nobs[lab, j] += 1
@@ -1086,10 +1072,8 @@ def group_ohlc(
1086
1072
1087
1073
if uses_mask:
1088
1074
isna_entry = mask[i, 0 ]
1089
- elif int64float_t is float32_t or int64float_t is float64_t:
1090
- isna_entry = val != val
1091
1075
else :
1092
- isna_entry = False
1076
+ isna_entry = _treat_as_na(val, False )
1093
1077
1094
1078
if isna_entry:
1095
1079
continue
@@ -1231,15 +1215,26 @@ def group_quantile(
1231
1215
# group_nth, group_last, group_rank
1232
1216
# ----------------------------------------------------------------------
1233
1217
1234
- cdef bint _treat_as_na(numeric_object_t val, bint is_datetimelike) nogil:
1235
- if numeric_object_t is object :
1218
+ ctypedef fused numeric_object_complex_t:
1219
+ numeric_object_t
1220
+ complex64_t
1221
+ complex128_t
1222
+
1223
+
1224
+ cdef bint _treat_as_na(numeric_object_complex_t val, bint is_datetimelike) nogil:
1225
+ if numeric_object_complex_t is object :
1236
1226
# Should never be used, but we need to avoid the `val != val` below
1237
1227
# or else cython will raise about gil acquisition.
1238
1228
raise NotImplementedError
1239
1229
1240
- elif numeric_object_t is int64_t:
1230
+ elif numeric_object_complex_t is int64_t:
1241
1231
return is_datetimelike and val == NPY_NAT
1242
- elif numeric_object_t is float32_t or numeric_object_t is float64_t:
1232
+ elif (
1233
+ numeric_object_complex_t is float32_t
1234
+ or numeric_object_complex_t is float64_t
1235
+ or numeric_object_complex_t is complex64_t
1236
+ or numeric_object_complex_t is complex128_t
1237
+ ):
1243
1238
return val != val
1244
1239
else :
1245
1240
# non-datetimelike integer
0 commit comments