@@ -759,13 +759,16 @@ def group_var(
759
759
const intp_t[::1] labels ,
760
760
Py_ssize_t min_count = - 1 ,
761
761
int64_t ddof = 1 ,
762
+ const uint8_t[:, ::1] mask = None ,
763
+ uint8_t[:, ::1] result_mask = None ,
762
764
) -> None:
763
765
cdef:
764
766
Py_ssize_t i , j , N , K , lab , ncounts = len (counts)
765
767
floating val , ct , oldmean
766
768
floating[:, ::1] mean
767
769
int64_t[:, ::1] nobs
768
770
Py_ssize_t len_values = len (values), len_labels = len (labels)
771
+ bint isna_entry , uses_mask = not mask is None
769
772
770
773
assert min_count == -1, "'min_count' only used in sum and prod"
771
774
@@ -790,8 +793,12 @@ def group_var(
790
793
for j in range (K):
791
794
val = values[i, j]
792
795
793
- # not nan
794
- if val == val:
796
+ if uses_mask:
797
+ isna_entry = mask[i, j]
798
+ else :
799
+ isna_entry = not val == val
800
+
801
+ if not isna_entry:
795
802
nobs[lab, j] += 1
796
803
oldmean = mean[lab, j]
797
804
mean[lab, j] += (val - oldmean) / nobs[lab, j]
@@ -801,7 +808,10 @@ def group_var(
801
808
for j in range (K):
802
809
ct = nobs[i, j]
803
810
if ct <= ddof:
804
- out[i, j] = NAN
811
+ if uses_mask:
812
+ result_mask[i, j] = True
813
+ else :
814
+ out[i, j] = NAN
805
815
else :
806
816
out[i, j] /= (ct - ddof)
807
817
@@ -839,9 +849,9 @@ def group_mean(
839
849
is_datetimelike : bool
840
850
True if `values` contains datetime-like entries.
841
851
mask : ndarray[bool , ndim = 2 ], optional
842
- Not used .
852
+ Mask of the input values .
843
853
result_mask : ndarray[bool , ndim = 2 ], optional
844
- Not used.
854
+ Mask of the out array
845
855
846
856
Notes
847
857
-----
@@ -855,6 +865,7 @@ def group_mean(
855
865
mean_t[:, ::1] sumx , compensation
856
866
int64_t[:, ::1] nobs
857
867
Py_ssize_t len_values = len (values), len_labels = len (labels)
868
+ bint isna_entry , uses_mask = not mask is None
858
869
859
870
assert min_count == -1, "'min_count' only used in sum and prod"
860
871
@@ -867,7 +878,12 @@ def group_mean(
867
878
compensation = np.zeros((< object > out).shape, dtype = (< object > out).base.dtype)
868
879
869
880
N , K = (< object > values).shape
870
- nan_val = NPY_NAT if is_datetimelike else NAN
881
+ if uses_mask:
882
+ nan_val = 0
883
+ elif is_datetimelike:
884
+ nan_val = NPY_NAT
885
+ else:
886
+ nan_val = NAN
871
887
872
888
with nogil:
873
889
for i in range(N ):
@@ -878,8 +894,15 @@ def group_mean(
878
894
counts[lab] += 1
879
895
for j in range (K):
880
896
val = values[i, j]
881
- # not nan
882
- if val == val and not (is_datetimelike and val == NPY_NAT):
897
+
898
+ if uses_mask:
899
+ isna_entry = mask[i, j]
900
+ elif is_datetimelike:
901
+ isna_entry = val == NPY_NAT
902
+ else :
903
+ isna_entry = not val == val
904
+
905
+ if not isna_entry:
883
906
nobs[lab, j] += 1
884
907
y = val - compensation[lab, j]
885
908
t = sumx[lab, j] + y
@@ -890,7 +913,12 @@ def group_mean(
890
913
for j in range (K):
891
914
count = nobs[i, j]
892
915
if nobs[i, j] == 0 :
893
- out[i, j] = nan_val
916
+
917
+ if uses_mask:
918
+ result_mask[i, j] = True
919
+ else :
920
+ out[i, j] = nan_val
921
+
894
922
else :
895
923
out[i, j] = sumx[i, j] / count
896
924
0 commit comments