Skip to content

Commit b07809b

Browse files
jbrockmendel authored and yehoshuadimarsky committed
REF: use fused types for mode (pandas-dev#46089)
1 parent 10efdc5 commit b07809b

File tree

1 file changed

+39
-85
lines changed

1 file changed

+39
-85
lines changed

pandas/_libs/hashtable_func_helper.pxi.in

+39-85
Original file line number | Diff line number | Diff line change
@@ -240,63 +240,6 @@ cdef ismember_{{dtype}}(const {{dtype}}_t[:] arr, const {{dtype}}_t[:] values):
240240
# Mode Computations
241241
# ----------------------------------------------------------------------
242242

243-
244-
@cython.wraparound(False)
245-
@cython.boundscheck(False)
246-
{{if dtype == 'object'}}
247-
cdef mode_{{dtype}}(ndarray[{{dtype}}] values, bint dropna):
248-
{{else}}
249-
cdef mode_{{dtype}}(const {{dtype}}_t[:] values, bint dropna):
250-
{{endif}}
251-
cdef:
252-
{{if dtype == 'object'}}
253-
ndarray[{{dtype}}] keys
254-
ndarray[{{dtype}}] modes
255-
{{else}}
256-
{{dtype}}_t[:] keys
257-
ndarray[{{dtype}}_t] modes
258-
{{endif}}
259-
int64_t[:] counts
260-
int64_t count, max_count = -1
261-
Py_ssize_t k, j = 0
262-
263-
keys, counts = value_count_{{dtype}}(values, dropna)
264-
265-
{{if dtype == 'object'}}
266-
modes = np.empty(len(keys), dtype=np.object_)
267-
{{else}}
268-
modes = np.empty(len(keys), dtype=np.{{dtype}})
269-
{{endif}}
270-
271-
{{if dtype != 'object'}}
272-
with nogil:
273-
for k in range(len(keys)):
274-
count = counts[k]
275-
if count == max_count:
276-
j += 1
277-
elif count > max_count:
278-
max_count = count
279-
j = 0
280-
else:
281-
continue
282-
283-
modes[j] = keys[k]
284-
{{else}}
285-
for k in range(len(keys)):
286-
count = counts[k]
287-
if count == max_count:
288-
j += 1
289-
elif count > max_count:
290-
max_count = count
291-
j = 0
292-
else:
293-
continue
294-
295-
modes[j] = keys[k]
296-
{{endif}}
297-
298-
return modes[:j + 1]
299-
300243
{{endfor}}
301244

302245

@@ -414,40 +357,51 @@ cpdef ismember(ndarray[htfunc_t] arr, ndarray[htfunc_t] values):
414357
raise TypeError(values.dtype)
415358

416359

417-
cpdef mode(ndarray[htfunc_t] values, bint dropna):
418-
if htfunc_t is object:
419-
return mode_object(values, dropna)
360+
@cython.wraparound(False)
361+
@cython.boundscheck(False)
362+
def mode(ndarray[htfunc_t] values, bint dropna):
363+
# TODO(cython3): use const htfunc_t[:]
420364

421-
elif htfunc_t is int8_t:
422-
return mode_int8(values, dropna)
423-
elif htfunc_t is int16_t:
424-
return mode_int16(values, dropna)
425-
elif htfunc_t is int32_t:
426-
return mode_int32(values, dropna)
427-
elif htfunc_t is int64_t:
428-
return mode_int64(values, dropna)
365+
cdef:
366+
ndarray[htfunc_t] keys
367+
ndarray[htfunc_t] modes
429368

430-
elif htfunc_t is uint8_t:
431-
return mode_uint8(values, dropna)
432-
elif htfunc_t is uint16_t:
433-
return mode_uint16(values, dropna)
434-
elif htfunc_t is uint32_t:
435-
return mode_uint32(values, dropna)
436-
elif htfunc_t is uint64_t:
437-
return mode_uint64(values, dropna)
369+
int64_t[:] counts
370+
int64_t count, max_count = -1
371+
Py_ssize_t nkeys, k, j = 0
438372

439-
elif htfunc_t is float64_t:
440-
return mode_float64(values, dropna)
441-
elif htfunc_t is float32_t:
442-
return mode_float32(values, dropna)
373+
keys, counts = value_count(values, dropna)
374+
nkeys = len(keys)
443375

444-
elif htfunc_t is complex128_t:
445-
return mode_complex128(values, dropna)
446-
elif htfunc_t is complex64_t:
447-
return mode_complex64(values, dropna)
376+
modes = np.empty(nkeys, dtype=values.dtype)
448377

378+
if htfunc_t is not object:
379+
with nogil:
380+
for k in range(nkeys):
381+
count = counts[k]
382+
if count == max_count:
383+
j += 1
384+
elif count > max_count:
385+
max_count = count
386+
j = 0
387+
else:
388+
continue
389+
390+
modes[j] = keys[k]
449391
else:
450-
raise TypeError(values.dtype)
392+
for k in range(nkeys):
393+
count = counts[k]
394+
if count == max_count:
395+
j += 1
396+
elif count > max_count:
397+
max_count = count
398+
j = 0
399+
else:
400+
continue
401+
402+
modes[j] = keys[k]
403+
404+
return modes[:j + 1]
451405

452406

453407
{{py:

0 commit comments

Comments
 (0)