Skip to content

Commit c4edfc5

Browse files
phoflnoatamir
authored andcommitted
ENH: Suppport masks in median groupby algo (pandas-dev#48387)
* ENH: Suppport masks in median groupby algo * Avoid float cast * Out dtype * Add cast * Revert algos * Add whatsnew * Deduplicate * Fix * Add type hints * Move free
1 parent b72c5a7 commit c4edfc5

File tree

5 files changed

+102
-20
lines changed

5 files changed

+102
-20
lines changed

asv_bench/benchmarks/groupby.py

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66

77
from pandas import (
8+
NA,
89
Categorical,
910
DataFrame,
1011
Index,
@@ -592,6 +593,7 @@ def setup(self, dtype, method):
592593
columns=list("abcdefghij"),
593594
dtype=dtype,
594595
)
596+
df.loc[list(range(1, N, 5)), list("abcdefghij")] = NA
595597
df["key"] = np.random.randint(0, 100, size=N)
596598
self.df = df
597599

doc/source/whatsnew/v1.6.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ Deprecations
100100

101101
Performance improvements
102102
~~~~~~~~~~~~~~~~~~~~~~~~
103+
- Performance improvement in :meth:`.GroupBy.median` for nullable dtypes (:issue:`37493`)
103104
- Performance improvement in :meth:`.GroupBy.mean` and :meth:`.GroupBy.var` for extension array dtypes (:issue:`37493`)
104105
- Performance improvement for :meth:`MultiIndex.unique` (:issue:`48335`)
105106
-

pandas/_libs/groupby.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ def group_median_float64(
1010
values: np.ndarray, # ndarray[float64_t, ndim=2]
1111
labels: npt.NDArray[np.int64],
1212
min_count: int = ..., # Py_ssize_t
13+
mask: np.ndarray | None = ...,
14+
result_mask: np.ndarray | None = ...,
1315
) -> None: ...
1416
def group_cumprod_float64(
1517
out: np.ndarray, # float64_t[:, ::1]

pandas/_libs/groupby.pyx

+95-19
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ from pandas._libs.algos import (
4141
ensure_platform_int,
4242
groupsort_indexer,
4343
rank_1d,
44+
take_2d_axis1_bool_bool,
4445
take_2d_axis1_float64_float64,
4546
)
4647

@@ -64,11 +65,48 @@ cdef enum InterpolationEnumType:
6465
INTERPOLATION_MIDPOINT
6566

6667

67-
cdef inline float64_t median_linear(float64_t* a, int n) nogil:
68+
cdef inline float64_t median_linear_mask(float64_t* a, int n, uint8_t* mask) nogil:
6869
cdef:
6970
int i, j, na_count = 0
71+
float64_t* tmp
7072
float64_t result
73+
74+
if n == 0:
75+
return NaN
76+
77+
# count NAs
78+
for i in range(n):
79+
if mask[i]:
80+
na_count += 1
81+
82+
if na_count:
83+
if na_count == n:
84+
return NaN
85+
86+
tmp = <float64_t*>malloc((n - na_count) * sizeof(float64_t))
87+
88+
j = 0
89+
for i in range(n):
90+
if not mask[i]:
91+
tmp[j] = a[i]
92+
j += 1
93+
94+
a = tmp
95+
n -= na_count
96+
97+
result = calc_median_linear(a, n, na_count)
98+
99+
if na_count:
100+
free(a)
101+
102+
return result
103+
104+
105+
cdef inline float64_t median_linear(float64_t* a, int n) nogil:
106+
cdef:
107+
int i, j, na_count = 0
71108
float64_t* tmp
109+
float64_t result
72110

73111
if n == 0:
74112
return NaN
@@ -93,18 +131,34 @@ cdef inline float64_t median_linear(float64_t* a, int n) nogil:
93131
a = tmp
94132
n -= na_count
95133

134+
result = calc_median_linear(a, n, na_count)
135+
136+
if na_count:
137+
free(a)
138+
139+
return result
140+
141+
142+
cdef inline float64_t calc_median_linear(float64_t* a, int n, int na_count) nogil:
143+
cdef:
144+
float64_t result
145+
96146
if n % 2:
97147
result = kth_smallest_c(a, n // 2, n)
98148
else:
99149
result = (kth_smallest_c(a, n // 2, n) +
100150
kth_smallest_c(a, n // 2 - 1, n)) / 2
101151

102-
if na_count:
103-
free(a)
104-
105152
return result
106153

107154

155+
ctypedef fused int64float_t:
156+
int64_t
157+
uint64_t
158+
float32_t
159+
float64_t
160+
161+
108162
@cython.boundscheck(False)
109163
@cython.wraparound(False)
110164
def group_median_float64(
@@ -113,6 +167,8 @@ def group_median_float64(
113167
ndarray[float64_t, ndim=2] values,
114168
ndarray[intp_t] labels,
115169
Py_ssize_t min_count=-1,
170+
const uint8_t[:, :] mask=None,
171+
uint8_t[:, ::1] result_mask=None,
116172
) -> None:
117173
"""
118174
Only aggregates on axis=0
@@ -121,8 +177,12 @@ def group_median_float64(
121177
Py_ssize_t i, j, N, K, ngroups, size
122178
ndarray[intp_t] _counts
123179
ndarray[float64_t, ndim=2] data
180+
ndarray[uint8_t, ndim=2] data_mask
124181
ndarray[intp_t] indexer
125182
float64_t* ptr
183+
uint8_t* ptr_mask
184+
float64_t result
185+
bint uses_mask = mask is not None
126186

127187
assert min_count == -1, "'min_count' only used in sum and prod"
128188

@@ -137,15 +197,38 @@ def group_median_float64(
137197

138198
take_2d_axis1_float64_float64(values.T, indexer, out=data)
139199

140-
with nogil:
200+
if uses_mask:
201+
data_mask = np.empty((K, N), dtype=np.uint8)
202+
ptr_mask = <uint8_t *>cnp.PyArray_DATA(data_mask)
203+
204+
take_2d_axis1_bool_bool(mask.T, indexer, out=data_mask, fill_value=1)
141205

142-
for i in range(K):
143-
# exclude NA group
144-
ptr += _counts[0]
145-
for j in range(ngroups):
146-
size = _counts[j + 1]
147-
out[j, i] = median_linear(ptr, size)
148-
ptr += size
206+
with nogil:
207+
208+
for i in range(K):
209+
# exclude NA group
210+
ptr += _counts[0]
211+
ptr_mask += _counts[0]
212+
213+
for j in range(ngroups):
214+
size = _counts[j + 1]
215+
result = median_linear_mask(ptr, size, ptr_mask)
216+
out[j, i] = result
217+
218+
if result != result:
219+
result_mask[j, i] = 1
220+
ptr += size
221+
ptr_mask += size
222+
223+
else:
224+
with nogil:
225+
for i in range(K):
226+
# exclude NA group
227+
ptr += _counts[0]
228+
for j in range(ngroups):
229+
size = _counts[j + 1]
230+
out[j, i] = median_linear(ptr, size)
231+
ptr += size
149232

150233

151234
@cython.boundscheck(False)
@@ -206,13 +289,6 @@ def group_cumprod_float64(
206289
accum[lab, j] = NaN
207290

208291

209-
ctypedef fused int64float_t:
210-
int64_t
211-
uint64_t
212-
float32_t
213-
float64_t
214-
215-
216292
@cython.boundscheck(False)
217293
@cython.wraparound(False)
218294
def group_cumsum(

pandas/core/groupby/ops.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
162162
"prod",
163163
"mean",
164164
"var",
165+
"median",
165166
}
166167

167168
_cython_arity = {"ohlc": 4} # OHLC
@@ -600,7 +601,7 @@ def _call_cython_op(
600601
min_count=min_count,
601602
is_datetimelike=is_datetimelike,
602603
)
603-
elif self.how in ["var", "ohlc", "prod"]:
604+
elif self.how in ["var", "ohlc", "prod", "median"]:
604605
func(
605606
result,
606607
counts,

0 commit comments

Comments
 (0)