Skip to content

Commit 0af6cea

Browse files
committed
Wired any/all into _get_cythonized_result
1 parent 7ff8e48 commit 0af6cea

File tree

2 files changed

+119
-35
lines changed

2 files changed

+119
-35
lines changed

pandas/_libs/groupby.pyx

+14-19
Original file line numberDiff line numberDiff line change
@@ -314,18 +314,20 @@ def group_fillna_indexer(ndarray[int64_t] out, ndarray[int64_t] labels,
314314

315315
@cython.boundscheck(False)
316316
@cython.wraparound(False)
317-
def group_any(ndarray[int64_t] out,
318-
ndarray values,
317+
def group_any(ndarray[uint8_t] out,
319318
ndarray[int64_t] labels,
319+
ndarray[uint8_t] values,
320+
ndarray[uint8_t] mask,
320321
bint skipna):
321322
"""Aggregated boolean values to show if any group element is truthful
322323
323324
Parameters
324325
----------
325-
out : array of int64_t values which this method will write its results to
326-
values : array of values to be truth-tested
326+
out : array of values which this method will write its results to
327327
labels : array containing unique label for each group, with its ordering
328328
matching up to the corresponding record in `values`
329+
values : array containing the truth value of each element
330+
mask : array indicating whether a value is na or not
329331
skipna : boolean
330332
Flag to ignore nan values during truth testing
331333
@@ -337,40 +339,33 @@ def group_any(ndarray[int64_t] out,
337339
cdef:
338340
Py_ssize_t i, N=len(labels)
339341
int64_t lab
340-
ndarray[int64_t] bool_mask
341-
ndarray[uint8_t] isna_mask
342-
343-
if values.dtype == 'object':
344-
bool_mask = np.array([bool(x) for x in values]).astype(np.int64)
345-
isna_mask = missing.isnaobj(values).astype(np.uint8)
346-
else:
347-
bool_mask = values.astype(np.bool).astype(np.int64)
348-
isna_mask = np.isnan(values).astype(np.uint8)
349342

350343
with nogil:
351344
for i in range(N):
352345
lab = labels[i]
353-
if lab < 0 or (skipna and isna_mask[i]):
346+
if lab < 0 or (skipna and mask[i]):
354347
continue
355348

356-
if bool_mask[i]:
349+
if values[i]:
357350
out[lab] = 1
358351

359352

360353
@cython.boundscheck(False)
361354
@cython.wraparound(False)
362-
def group_all(ndarray[int64_t] out,
363-
ndarray values,
355+
def group_all(ndarray[uint8_t] out,
364356
ndarray[int64_t] labels,
357+
ndarray[uint8_t] values,
358+
ndarray[uint8_t] mask,
365359
bint skipna):
366360
"""Aggregated boolean values to show if all group elements are truthful
367361
368362
Parameters
369363
----------
370-
out : array of int64_t values which this method will write its results to
371-
values : array of values to be truth-tested
364+
out : array of values which this method will write its results to
372365
labels : array containing unique label for each group, with its ordering
373366
matching up to the corresponding record in `values`
367+
values : array containing the truth value of each element
368+
mask : array indicating whether a value is na or not
374369
skipna : boolean
375370
Flag to ignore nan values during truth testing
376371

pandas/core/groupby.py

+105-16
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,29 @@ class GroupBy(_GroupBy):
12191219
"""
12201220
_apply_whitelist = _common_apply_whitelist
12211221

1222+
def _bool_agg(self, how, skipna):
1223+
"""Shared func to call any / all Cython GroupBy implementations"""
1224+
1225+
def objs_to_bool(vals):
1226+
try:
1227+
vals = vals.astype(np.bool)
1228+
except ValueError: # for objects
1229+
vals = np.array([bool(x) for x in vals])
1230+
1231+
return vals.view(np.uint8)
1232+
1233+
def result_to_bool(result):
1234+
return result.astype(np.bool, copy=False)
1235+
1236+
return self._get_cythonized_result(how, self.grouper,
1237+
aggregate=True,
1238+
cython_dtype=np.uint8,
1239+
needs_values=True,
1240+
needs_mask=True,
1241+
pre_processing=objs_to_bool,
1242+
post_processing=result_to_bool,
1243+
skipna=skipna)
1244+
12221245
@Substitution(name='groupby')
12231246
@Appender(_doc_template)
12241247
def any(self, skipna=True):
@@ -1229,15 +1252,19 @@ def any(self, skipna=True):
12291252
skipna : bool, default True
12301253
Flag to ignore nan values during truth testing
12311254
"""
1232-
labels, _, _ = self.grouper.group_info
1233-
output = collections.OrderedDict()
1255+
return self._bool_agg('group_any', skipna)
12341256

1235-
for name, obj in self._iterate_slices():
1236-
result = np.zeros(self.ngroups, dtype=np.int64)
1237-
libgroupby.group_any(result, obj.values, labels, skipna)
1238-
output[name] = result.astype(np.bool)
1257+
@Substitution(name='groupby')
1258+
@Appender(_doc_template)
1259+
def all(self, skipna=True):
1260+
"""Returns True if all values in the group are truthful, else False
12391261
1240-
return self._wrap_aggregated_output(output)
1262+
Parameters
1263+
----------
1264+
skipna : bool, default True
1265+
Flag to ignore nan values during truth testing
1266+
"""
1267+
return self._bool_agg('group_all', skipna)
12411268

12421269
@Substitution(name='groupby')
12431270
@Appender(_doc_template)
@@ -1505,6 +1532,8 @@ def _fill(self, direction, limit=None):
15051532

15061533
return self._get_cythonized_result('group_fillna_indexer',
15071534
self.grouper, needs_mask=True,
1535+
cython_dtype=np.int64,
1536+
result_is_index=True,
15081537
direction=direction, limit=limit)
15091538

15101539
@Substitution(name='groupby')
@@ -1893,33 +1922,81 @@ def cummax(self, axis=0, **kwargs):
18931922

18941923
return self._cython_transform('cummax', numeric_only=False)
18951924

1896-
def _get_cythonized_result(self, how, grouper, needs_mask=False,
1897-
needs_ngroups=False, **kwargs):
1925+
def _get_cythonized_result(self, how, grouper, aggregate=False,
1926+
cython_dtype=None, needs_values=False,
1927+
needs_mask=False, needs_ngroups=False,
1928+
result_is_index=False,
1929+
pre_processing=None, post_processing=None,
1930+
**kwargs):
18981931
"""Get result for Cythonized functions
18991932
19001933
Parameters
19011934
----------
19021935
how : str, Cythonized function name to be called
19031936
grouper : Grouper object containing pertinent group info
1937+
aggregate : bool, default False
1938+
Whether the result should be aggregated to match the number of
1939+
groups
1940+
cython_dtype : default None
1941+
Type of the array that will be modified by the Cython call. If
1942+
`None`, the type will be inferred from the values of each slice
1943+
needs_values : bool, default False
1944+
Whether the values should be a part of the Cython call
1945+
signature
19041946
needs_mask : bool, default False
1905-
Whether boolean mask needs to be part of the Cython call signature
1947+
Whether boolean mask needs to be part of the Cython call
1948+
signature
19061949
needs_ngroups : bool, default False
1907-
Whether number of groups part of the Cython call signature
1950+
Whether number of groups is part of the Cython call signature
1951+
result_is_index : bool, default False
1952+
Whether the result of the Cython operation is an index of
1953+
values to be retrieved, instead of the actual values themselves
1954+
pre_processing : function, default None
1955+
Function to be applied to `values` prior to passing to Cython.
1956+
Raises if `needs_values` is False
1957+
post_processing : function, default None
1958+
Function to be applied to result of Cython function
19081959
**kwargs : dict
19091960
Extra arguments to be passed back to Cython funcs
19101961
19111962
Returns
19121963
-------
19131964
`Series` or `DataFrame` with aggregated or transformed values
19141965
"""
1966+
if result_is_index and aggregate:
1967+
raise ValueError("'result_is_index' and 'aggregate' cannot both "
1968+
"be True!")
1969+
if post_processing:
1970+
if not callable(post_processing):
1971+
raise ValueError("'post_processing' must be a callable!")
1972+
if pre_processing:
1973+
if not callable(pre_processing):
1974+
raise ValueError("'pre_processing' must be a callable!")
1975+
if not needs_values:
1976+
raise ValueError("Cannot use 'pre_processing' without "
1977+
"specifying 'needs_values'!")
19151978

19161979
labels, _, ngroups = grouper.group_info
19171980
output = collections.OrderedDict()
19181981
base_func = getattr(libgroupby, how)
19191982

19201983
for name, obj in self._iterate_slices():
1921-
indexer = np.zeros_like(labels, dtype=np.int64)
1922-
func = partial(base_func, indexer, labels)
1984+
if aggregate:
1985+
result_sz = ngroups
1986+
else:
1987+
result_sz = len(obj.values)
1988+
1989+
if not cython_dtype:
1990+
cython_dtype = obj.values.dtype
1991+
1992+
result = np.zeros(result_sz, dtype=cython_dtype)
1993+
func = partial(base_func, result, labels)
1994+
if needs_values:
1995+
vals = obj.values
1996+
if pre_processing:
1997+
vals = pre_processing(vals)
1998+
func = partial(func, vals)
1999+
19232000
if needs_mask:
19242001
mask = isnull(obj.values).view(np.uint8)
19252002
func = partial(func, mask)
@@ -1928,9 +2005,19 @@ def _get_cythonized_result(self, how, grouper, needs_mask=False,
19282005
func = partial(func, ngroups)
19292006

19302007
func(**kwargs) # Call func to modify result values in place
1931-
output[name] = algorithms.take_nd(obj.values, indexer)
19322008

1933-
return self._wrap_transformed_output(output)
2009+
if result_is_index:
2010+
result = algorithms.take_nd(obj.values, result)
2011+
2012+
if post_processing:
2013+
result = post_processing(result)
2014+
2015+
output[name] = result
2016+
2017+
if aggregate:
2018+
return self._wrap_aggregated_output(output)
2019+
else:
2020+
return self._wrap_transformed_output(output)
19342021

19352022
@Substitution(name='groupby')
19362023
@Appender(_doc_template)
@@ -1950,7 +2037,9 @@ def shift(self, periods=1, freq=None, axis=0):
19502037
return self.apply(lambda x: x.shift(periods, freq, axis))
19512038

19522039
return self._get_cythonized_result('group_shift_indexer',
1953-
self.grouper, needs_ngroups=True,
2040+
self.grouper, cython_dtype=np.int64,
2041+
needs_ngroups=True,
2042+
result_is_index=True,
19542043
periods=periods)
19552044

19562045
@Substitution(name='groupby')

0 commit comments

Comments
 (0)