Skip to content

Commit ae61ccd

Browse files
authored
REF: use _treat_as_na more (#51067)
1 parent a254bcf commit ae61ccd

File tree

1 file changed

+24
-29
lines changed

1 file changed

+24
-29
lines changed

pandas/_libs/groupby.pyx

+24-29
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,8 @@ def group_cumprod(
289289

290290
if uses_mask:
291291
isna_entry = mask[i, j]
292-
elif int64float_t is float64_t or int64float_t is float32_t:
293-
isna_entry = val != val
294292
else:
295-
isna_entry = False
293+
isna_entry = _treat_as_na(val, False)
296294

297295
if not isna_entry:
298296
isna_prev = accum_mask[lab, j]
@@ -737,23 +735,10 @@ def group_sum(
737735
for j in range(K):
738736
val = values[i, j]
739737

740-
# not nan
741-
# With dt64/td64 values, values have been cast to float64
742-
# instead of int64 for group_sum, but the logic
743-
# is otherwise the same as in _treat_as_na
744738
if uses_mask:
745739
isna_entry = mask[i, j]
746-
elif (
747-
sum_t is float32_t
748-
or sum_t is float64_t
749-
or sum_t is complex64_t
750-
):
751-
# avoid warnings because of equality comparison
752-
isna_entry = not val == val
753-
elif sum_t is int64_t and is_datetimelike and val == NPY_NAT:
754-
isna_entry = True
755740
else:
756-
isna_entry = False
741+
isna_entry = _treat_as_na(val, is_datetimelike)
757742

758743
if not isna_entry:
759744
nobs[lab, j] += 1
@@ -831,10 +816,8 @@ def group_prod(
831816

832817
if uses_mask:
833818
isna_entry = mask[i, j]
834-
elif int64float_t is float32_t or int64float_t is float64_t:
835-
isna_entry = not val == val
836819
else:
837-
isna_entry = False
820+
isna_entry = _treat_as_na(val, False)
838821

839822
if not isna_entry:
840823
nobs[lab, j] += 1
@@ -906,7 +889,7 @@ def group_var(
906889
if uses_mask:
907890
isna_entry = mask[i, j]
908891
else:
909-
isna_entry = not val == val
892+
isna_entry = _treat_as_na(val, False)
910893

911894
if not isna_entry:
912895
nobs[lab, j] += 1
@@ -1008,9 +991,12 @@ def group_mean(
1008991
if uses_mask:
1009992
isna_entry = mask[i, j]
1010993
elif is_datetimelike:
994+
# With group_mean, we cannot just use _treat_as_na because
995+
# datetimelike dtypes get cast to float64 instead of
996+
# to int64.
1011997
isna_entry = val == NPY_NAT
1012998
else:
1013-
isna_entry = not val == val
999+
isna_entry = _treat_as_na(val, is_datetimelike)
10141000

10151001
if not isna_entry:
10161002
nobs[lab, j] += 1
@@ -1086,10 +1072,8 @@ def group_ohlc(
10861072

10871073
if uses_mask:
10881074
isna_entry = mask[i, 0]
1089-
elif int64float_t is float32_t or int64float_t is float64_t:
1090-
isna_entry = val != val
10911075
else:
1092-
isna_entry = False
1076+
isna_entry = _treat_as_na(val, False)
10931077

10941078
if isna_entry:
10951079
continue
@@ -1231,15 +1215,26 @@ def group_quantile(
12311215
# group_nth, group_last, group_rank
12321216
# ----------------------------------------------------------------------
12331217

1234-
cdef bint _treat_as_na(numeric_object_t val, bint is_datetimelike) nogil:
1235-
if numeric_object_t is object:
1218+
ctypedef fused numeric_object_complex_t:
1219+
numeric_object_t
1220+
complex64_t
1221+
complex128_t
1222+
1223+
1224+
cdef bint _treat_as_na(numeric_object_complex_t val, bint is_datetimelike) nogil:
1225+
if numeric_object_complex_t is object:
12361226
# Should never be used, but we need to avoid the `val != val` below
12371227
# or else cython will raise about gil acquisition.
12381228
raise NotImplementedError
12391229

1240-
elif numeric_object_t is int64_t:
1230+
elif numeric_object_complex_t is int64_t:
12411231
return is_datetimelike and val == NPY_NAT
1242-
elif numeric_object_t is float32_t or numeric_object_t is float64_t:
1232+
elif (
1233+
numeric_object_complex_t is float32_t
1234+
or numeric_object_complex_t is float64_t
1235+
or numeric_object_complex_t is complex64_t
1236+
or numeric_object_complex_t is complex128_t
1237+
):
12431238
return val != val
12441239
else:
12451240
# non-datetimelike integer

0 commit comments

Comments
 (0)