@@ -41,6 +41,7 @@ from pandas._libs.algos import (
41
41
ensure_platform_int,
42
42
groupsort_indexer,
43
43
rank_1d,
44
+ take_2d_axis1_bool_bool,
44
45
take_2d_axis1_float64_float64,
45
46
)
46
47
@@ -64,11 +65,48 @@ cdef enum InterpolationEnumType:
64
65
INTERPOLATION_MIDPOINT
65
66
66
67
67
- cdef inline float64_t median_linear (float64_t* a, int n) nogil:
68
+ cdef inline float64_t median_linear_mask (float64_t* a, int n, uint8_t * mask ) nogil:
68
69
cdef:
69
70
int i, j, na_count = 0
71
+ float64_t* tmp
70
72
float64_t result
73
+
74
+ if n == 0 :
75
+ return NaN
76
+
77
+ # count NAs
78
+ for i in range (n):
79
+ if mask[i]:
80
+ na_count += 1
81
+
82
+ if na_count:
83
+ if na_count == n:
84
+ return NaN
85
+
86
+ tmp = < float64_t* > malloc((n - na_count) * sizeof(float64_t))
87
+
88
+ j = 0
89
+ for i in range (n):
90
+ if not mask[i]:
91
+ tmp[j] = a[i]
92
+ j += 1
93
+
94
+ a = tmp
95
+ n -= na_count
96
+
97
+ result = calc_median_linear(a, n, na_count)
98
+
99
+ if na_count:
100
+ free(a)
101
+
102
+ return result
103
+
104
+
105
+ cdef inline float64_t median_linear(float64_t* a, int n) nogil:
106
+ cdef:
107
+ int i, j, na_count = 0
71
108
float64_t* tmp
109
+ float64_t result
72
110
73
111
if n == 0 :
74
112
return NaN
@@ -93,18 +131,34 @@ cdef inline float64_t median_linear(float64_t* a, int n) nogil:
93
131
a = tmp
94
132
n -= na_count
95
133
134
+ result = calc_median_linear(a, n, na_count)
135
+
136
+ if na_count:
137
+ free(a)
138
+
139
+ return result
140
+
141
+
142
+ cdef inline float64_t calc_median_linear(float64_t* a, int n, int na_count) nogil:
143
+ cdef:
144
+ float64_t result
145
+
96
146
if n % 2 :
97
147
result = kth_smallest_c(a, n // 2 , n)
98
148
else :
99
149
result = (kth_smallest_c(a, n // 2 , n) +
100
150
kth_smallest_c(a, n // 2 - 1 , n)) / 2
101
151
102
- if na_count:
103
- free(a)
104
-
105
152
return result
106
153
107
154
155
+ ctypedef fused int64float_t:
156
+ int64_t
157
+ uint64_t
158
+ float32_t
159
+ float64_t
160
+
161
+
108
162
@ cython.boundscheck (False )
109
163
@ cython.wraparound (False )
110
164
def group_median_float64 (
@@ -113,6 +167,8 @@ def group_median_float64(
113
167
ndarray[float64_t , ndim = 2 ] values,
114
168
ndarray[intp_t] labels ,
115
169
Py_ssize_t min_count = - 1 ,
170
+ const uint8_t[:, :] mask = None ,
171
+ uint8_t[:, ::1] result_mask = None ,
116
172
) -> None:
117
173
"""
118
174
Only aggregates on axis = 0
@@ -121,8 +177,12 @@ def group_median_float64(
121
177
Py_ssize_t i , j , N , K , ngroups , size
122
178
ndarray[intp_t] _counts
123
179
ndarray[float64_t , ndim = 2 ] data
180
+ ndarray[uint8_t , ndim = 2 ] data_mask
124
181
ndarray[intp_t] indexer
125
182
float64_t* ptr
183
+ uint8_t* ptr_mask
184
+ float64_t result
185
+ bint uses_mask = mask is not None
126
186
127
187
assert min_count == -1, "'min_count' only used in sum and prod"
128
188
@@ -137,15 +197,38 @@ def group_median_float64(
137
197
138
198
take_2d_axis1_float64_float64(values.T , indexer , out = data)
139
199
140
- with nogil:
200
+ if uses_mask:
201
+ data_mask = np.empty((K, N), dtype = np.uint8)
202
+ ptr_mask = < uint8_t * > cnp.PyArray_DATA(data_mask)
203
+
204
+ take_2d_axis1_bool_bool(mask.T , indexer , out = data_mask, fill_value = 1 )
141
205
142
- for i in range(K ):
143
- # exclude NA group
144
- ptr += _counts[0 ]
145
- for j in range (ngroups):
146
- size = _counts[j + 1 ]
147
- out[j, i] = median_linear(ptr, size)
148
- ptr += size
206
+ with nogil:
207
+
208
+ for i in range(K ):
209
+ # exclude NA group
210
+ ptr += _counts[0 ]
211
+ ptr_mask += _counts[0 ]
212
+
213
+ for j in range (ngroups):
214
+ size = _counts[j + 1 ]
215
+ result = median_linear_mask(ptr, size, ptr_mask)
216
+ out[j, i] = result
217
+
218
+ if result != result:
219
+ result_mask[j, i] = 1
220
+ ptr += size
221
+ ptr_mask += size
222
+
223
+ else :
224
+ with nogil:
225
+ for i in range (K):
226
+ # exclude NA group
227
+ ptr += _counts[0 ]
228
+ for j in range (ngroups):
229
+ size = _counts[j + 1 ]
230
+ out[j, i] = median_linear(ptr, size)
231
+ ptr += size
149
232
150
233
151
234
@ cython.boundscheck (False )
@@ -206,13 +289,6 @@ def group_cumprod_float64(
206
289
accum[lab, j] = NaN
207
290
208
291
209
- ctypedef fused int64float_t:
210
- int64_t
211
- uint64_t
212
- float32_t
213
- float64_t
214
-
215
-
216
292
@ cython.boundscheck (False )
217
293
@ cython.wraparound (False )
218
294
def group_cumsum (
0 commit comments