Skip to content

Commit 11e3dc2

Browse files
authored
PERF: Fix groupby skipna performance (#60871)
1 parent 8f802cd commit 11e3dc2

File tree

2 files changed

+64
-35
lines changed

2 files changed

+64
-35
lines changed

pandas/_libs/groupby.pyx

+63-34
Original file line numberDiff line numberDiff line change
@@ -753,16 +753,20 @@ def group_sum(
753753

754754
if uses_mask:
755755
isna_entry = mask[i, j]
756-
isna_result = result_mask[lab, j]
757756
else:
758757
isna_entry = _treat_as_na(val, is_datetimelike)
759-
isna_result = _treat_as_na(sumx[lab, j], is_datetimelike)
760758

761-
if not skipna and isna_result:
762-
# If sum is already NA, don't add to it. This is important for
763-
# datetimelikebecause adding a value to NPY_NAT may not result
764-
# in a NPY_NAT
765-
continue
759+
if not skipna:
760+
if uses_mask:
761+
isna_result = result_mask[lab, j]
762+
else:
763+
isna_result = _treat_as_na(sumx[lab, j], is_datetimelike)
764+
765+
if isna_result:
766+
# If sum is already NA, don't add to it. This is important for
767+
# datetimelikebecause adding a value to NPY_NAT may not result
768+
# in a NPY_NAT
769+
continue
766770

767771
if not isna_entry:
768772
nobs[lab, j] += 1
@@ -845,14 +849,18 @@ def group_prod(
845849

846850
if uses_mask:
847851
isna_entry = mask[i, j]
848-
isna_result = result_mask[lab, j]
849852
else:
850853
isna_entry = _treat_as_na(val, False)
851-
isna_result = _treat_as_na(prodx[lab, j], False)
852854

853-
if not skipna and isna_result:
854-
# If prod is already NA, no need to update it
855-
continue
855+
if not skipna:
856+
if uses_mask:
857+
isna_result = result_mask[lab, j]
858+
else:
859+
isna_result = _treat_as_na(prodx[lab, j], False)
860+
861+
if isna_result:
862+
# If prod is already NA, no need to update it
863+
continue
856864

857865
if not isna_entry:
858866
nobs[lab, j] += 1
@@ -919,22 +927,30 @@ def group_var(
919927

920928
if uses_mask:
921929
isna_entry = mask[i, j]
922-
isna_result = result_mask[lab, j]
923930
elif is_datetimelike:
924931
# With group_var, we cannot just use _treat_as_na bc
925932
# datetimelike dtypes get cast to float64 instead of
926933
# to int64.
927934
isna_entry = val == NPY_NAT
928-
isna_result = out[lab, j] == NPY_NAT
929935
else:
930936
isna_entry = _treat_as_na(val, is_datetimelike)
931-
isna_result = _treat_as_na(out[lab, j], is_datetimelike)
932937

933-
if not skipna and isna_result:
934-
# If aggregate is already NA, don't add to it. This is important for
935-
# datetimelike because adding a value to NPY_NAT may not result
936-
# in a NPY_NAT
937-
continue
938+
if not skipna:
939+
if uses_mask:
940+
isna_result = result_mask[lab, j]
941+
elif is_datetimelike:
942+
# With group_var, we cannot just use _treat_as_na bc
943+
# datetimelike dtypes get cast to float64 instead of
944+
# to int64.
945+
isna_result = out[lab, j] == NPY_NAT
946+
else:
947+
isna_result = _treat_as_na(out[lab, j], is_datetimelike)
948+
949+
if isna_result:
950+
# If aggregate is already NA, don't add to it. This is
951+
# important for datetimelike because adding a value to NPY_NAT
952+
# may not result in a NPY_NAT
953+
continue
938954

939955
if not isna_entry:
940956
nobs[lab, j] += 1
@@ -1232,22 +1248,30 @@ def group_mean(
12321248

12331249
if uses_mask:
12341250
isna_entry = mask[i, j]
1235-
isna_result = result_mask[lab, j]
12361251
elif is_datetimelike:
12371252
# With group_mean, we cannot just use _treat_as_na bc
12381253
# datetimelike dtypes get cast to float64 instead of
12391254
# to int64.
12401255
isna_entry = val == NPY_NAT
1241-
isna_result = sumx[lab, j] == NPY_NAT
12421256
else:
12431257
isna_entry = _treat_as_na(val, is_datetimelike)
1244-
isna_result = _treat_as_na(sumx[lab, j], is_datetimelike)
12451258

1246-
if not skipna and isna_result:
1247-
# If sum is already NA, don't add to it. This is important for
1248-
# datetimelike because adding a value to NPY_NAT may not result
1249-
# in NPY_NAT
1250-
continue
1259+
if not skipna:
1260+
if uses_mask:
1261+
isna_result = result_mask[lab, j]
1262+
elif is_datetimelike:
1263+
# With group_mean, we cannot just use _treat_as_na bc
1264+
# datetimelike dtypes get cast to float64 instead of
1265+
# to int64.
1266+
isna_result = sumx[lab, j] == NPY_NAT
1267+
else:
1268+
isna_result = _treat_as_na(sumx[lab, j], is_datetimelike)
1269+
1270+
if isna_result:
1271+
# If sum is already NA, don't add to it. This is important for
1272+
# datetimelike because adding a value to NPY_NAT may not result
1273+
# in NPY_NAT
1274+
continue
12511275

12521276
if not isna_entry:
12531277
nobs[lab, j] += 1
@@ -1909,15 +1933,20 @@ cdef group_min_max(
19091933

19101934
if uses_mask:
19111935
isna_entry = mask[i, j]
1912-
isna_result = result_mask[lab, j]
19131936
else:
19141937
isna_entry = _treat_as_na(val, is_datetimelike)
1915-
isna_result = _treat_as_na(group_min_or_max[lab, j],
1916-
is_datetimelike)
19171938

1918-
if not skipna and isna_result:
1919-
# If current min/max is already NA, it will always be NA
1920-
continue
1939+
if not skipna:
1940+
if uses_mask:
1941+
isna_result = result_mask[lab, j]
1942+
else:
1943+
isna_result = _treat_as_na(
1944+
group_min_or_max[lab, j], is_datetimelike
1945+
)
1946+
1947+
if isna_result:
1948+
# If current min/max is already NA, it will always be NA
1949+
continue
19211950

19221951
if not isna_entry:
19231952
nobs[lab, j] += 1

pandas/core/_numba/kernels/min_max_.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def grouped_min_max(
9898
for i in range(N):
9999
lab = labels[i]
100100
val = values[i]
101-
if lab < 0 or (nobs[lab] >= 1 and np.isnan(output[lab])):
101+
if lab < 0 or (not skipna and nobs[lab] >= 1 and np.isnan(output[lab])):
102102
continue
103103

104104
if values.dtype.kind == "i" or not np.isnan(val):

0 commit comments

Comments
 (0)