Skip to content

Commit 4a27697

Browse files
WillAyd authored and jreback committed
Cythonized GroupBy any (pandas-dev#19722)
1 parent 96b8bb1 commit 4a27697

File tree

6 files changed

+222
-13
lines changed

6 files changed

+222
-13
lines changed

asv_bench/benchmarks/groupby.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111
from .pandas_vb_common import setup # noqa
1212

1313

14+
method_blacklist = {
15+
'object': {'median', 'prod', 'sem', 'cumsum', 'sum', 'cummin', 'mean',
16+
'max', 'skew', 'cumprod', 'cummax', 'rank', 'pct_change', 'min',
17+
'var', 'mad', 'describe', 'std'}
18+
}
19+
20+
1421
class ApplyDictReturn(object):
1522
goal_time = 0.2
1623

@@ -153,6 +160,7 @@ def time_frame_nth_any(self, df):
153160
def time_frame_nth(self, df):
154161
df.groupby(0).nth(0)
155162

163+
156164
def time_series_nth_any(self, df):
157165
df[1].groupby(df[0]).nth(0, dropna='any')
158166

@@ -369,23 +377,27 @@ class GroupByMethods(object):
369377
goal_time = 0.2
370378

371379
param_names = ['dtype', 'method']
372-
params = [['int', 'float'],
380+
params = [['int', 'float', 'object'],
373381
['all', 'any', 'bfill', 'count', 'cumcount', 'cummax', 'cummin',
374382
'cumprod', 'cumsum', 'describe', 'ffill', 'first', 'head',
375383
'last', 'mad', 'max', 'min', 'median', 'mean', 'nunique',
376384
'pct_change', 'prod', 'rank', 'sem', 'shift', 'size', 'skew',
377385
'std', 'sum', 'tail', 'unique', 'value_counts', 'var']]
378386

379387
def setup(self, dtype, method):
388+
if method in method_blacklist.get(dtype, {}):
389+
raise NotImplementedError # skip benchmark
380390
ngroups = 1000
381391
size = ngroups * 2
382392
rng = np.arange(ngroups)
383393
values = rng.take(np.random.randint(0, ngroups, size=size))
384394
if dtype == 'int':
385395
key = np.random.randint(0, size, size=size)
386-
else:
396+
elif dtype == 'float':
387397
key = np.concatenate([np.random.random(ngroups) * 0.1,
388398
np.random.random(ngroups) * 10.0])
399+
elif dtype == 'object':
400+
key = ['foo'] * size
389401

390402
df = DataFrame({'values': values, 'key': key})
391403
self.df_groupby_method = getattr(df.groupby('key')['values'], method)

doc/source/api.rst

+5
Original file line numberDiff line numberDiff line change
@@ -2179,8 +2179,12 @@ Computations / Descriptive Stats
21792179
.. autosummary::
21802180
:toctree: generated/
21812181

2182+
GroupBy.all
2183+
GroupBy.any
2184+
GroupBy.bfill
21822185
GroupBy.count
21832186
GroupBy.cumcount
2187+
GroupBy.ffill
21842188
GroupBy.first
21852189
GroupBy.head
21862190
GroupBy.last
@@ -2192,6 +2196,7 @@ Computations / Descriptive Stats
21922196
GroupBy.nth
21932197
GroupBy.ohlc
21942198
GroupBy.prod
2199+
GroupBy.rank
21952200
GroupBy.size
21962201
GroupBy.sem
21972202
GroupBy.std

doc/source/whatsnew/v0.23.0.txt

+3-2
Original file line numberDiff line numberDiff line change
@@ -729,9 +729,10 @@ Performance Improvements
729729
- Improved performance of :func:`DataFrame.median` with ``axis=1`` when bottleneck is not installed (:issue:`16468`)
730730
- Improved performance of :func:`MultiIndex.get_loc` for large indexes, at the cost of a reduction in performance for small ones (:issue:`18519`)
731731
- Improved performance of pairwise ``.rolling()`` and ``.expanding()`` with ``.cov()`` and ``.corr()`` operations (:issue:`17917`)
732-
- Improved performance of :func:`DataFrameGroupBy.rank` (:issue:`15779`)
732+
- Improved performance of :func:`pandas.core.groupby.GroupBy.rank` (:issue:`15779`)
733733
- Improved performance of variable ``.rolling()`` on ``.min()`` and ``.max()`` (:issue:`19521`)
734-
- Improved performance of ``GroupBy.ffill`` and ``GroupBy.bfill`` (:issue:`11296`)
734+
- Improved performance of :func:`pandas.core.groupby.GroupBy.ffill` and :func:`pandas.core.groupby.GroupBy.bfill` (:issue:`11296`)
735+
- Improved performance of :func:`pandas.core.groupby.GroupBy.any` and :func:`pandas.core.groupby.GroupBy.all` (:issue:`15435`)
735736

736737
.. _whatsnew_0230.docs:
737738

pandas/_libs/groupby.pyx

+57
Original file line numberDiff line numberDiff line change
@@ -310,5 +310,62 @@ def group_fillna_indexer(ndarray[int64_t] out, ndarray[int64_t] labels,
310310
filled_vals = 0
311311

312312

313+
@cython.boundscheck(False)
314+
@cython.wraparound(False)
315+
def group_any_all(ndarray[uint8_t] out,
316+
ndarray[int64_t] labels,
317+
ndarray[uint8_t] values,
318+
ndarray[uint8_t] mask,
319+
object val_test,
320+
bint skipna):
321+
"""Aggregated boolean values to show truthfulness of group elements
322+
323+
Parameters
324+
----------
325+
out : array of values which this method will write its results to
326+
labels : array containing unique label for each group, with its
327+
ordering matching up to the corresponding record in `values`
328+
values : array containing the truth value of each element
329+
mask : array indicating whether a value is na or not
330+
val_test : str {'any', 'all'}
331+
String object dictating whether to use any or all truth testing
332+
skipna : boolean
333+
Flag to ignore nan values during truth testing
334+
335+
Notes
336+
-----
337+
This method modifies the `out` parameter rather than returning an object.
338+
The returned values will either be 0 or 1 (False or True, respectively).
339+
"""
340+
cdef:
341+
Py_ssize_t i, N=len(labels)
342+
int64_t lab
343+
uint8_t flag_val
344+
345+
if val_test == 'all':
346+
# Because the 'all' value of an empty iterable in Python is True we can
347+
# start with an array full of ones and set to zero when a False value
348+
# is encountered
349+
flag_val = 0
350+
elif val_test == 'any':
351+
# Because the 'any' value of an empty iterable in Python is False we
352+
# can start with an array full of zeros and set to one only if any
353+
# value encountered is True
354+
flag_val = 1
355+
else:
356+
raise ValueError("'val_test' must be either 'any' or 'all'!")
357+
358+
out.fill(1 - flag_val)
359+
360+
with nogil:
361+
for i in range(N):
362+
lab = labels[i]
363+
if lab < 0 or (skipna and mask[i]):
364+
continue
365+
366+
if values[i] == flag_val:
367+
out[lab] = flag_val
368+
369+
313370
# generated from template
314371
include "groupby_helper.pxi"

pandas/core/groupby.py

+118-9
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,53 @@ class GroupBy(_GroupBy):
12191219
"""
12201220
_apply_whitelist = _common_apply_whitelist
12211221

1222+
def _bool_agg(self, val_test, skipna):
1223+
"""Shared func to call any / all Cython GroupBy implementations"""
1224+
1225+
def objs_to_bool(vals):
1226+
try:
1227+
vals = vals.astype(np.bool)
1228+
except ValueError: # for objects
1229+
vals = np.array([bool(x) for x in vals])
1230+
1231+
return vals.view(np.uint8)
1232+
1233+
def result_to_bool(result):
1234+
return result.astype(np.bool, copy=False)
1235+
1236+
return self._get_cythonized_result('group_any_all', self.grouper,
1237+
aggregate=True,
1238+
cython_dtype=np.uint8,
1239+
needs_values=True,
1240+
needs_mask=True,
1241+
pre_processing=objs_to_bool,
1242+
post_processing=result_to_bool,
1243+
val_test=val_test, skipna=skipna)
1244+
1245+
@Substitution(name='groupby')
1246+
@Appender(_doc_template)
1247+
def any(self, skipna=True):
1248+
"""Returns True if any value in the group is truthful, else False
1249+
1250+
Parameters
1251+
----------
1252+
skipna : bool, default True
1253+
Flag to ignore nan values during truth testing
1254+
"""
1255+
return self._bool_agg('any', skipna)
1256+
1257+
@Substitution(name='groupby')
1258+
@Appender(_doc_template)
1259+
def all(self, skipna=True):
1260+
"""Returns True if all values in the group are truthful, else False
1261+
1262+
Parameters
1263+
----------
1264+
skipna : bool, default True
1265+
Flag to ignore nan values during truth testing
1266+
"""
1267+
return self._bool_agg('all', skipna)
1268+
12221269
@Substitution(name='groupby')
12231270
@Appender(_doc_template)
12241271
def count(self):
@@ -1485,6 +1532,8 @@ def _fill(self, direction, limit=None):
14851532

14861533
return self._get_cythonized_result('group_fillna_indexer',
14871534
self.grouper, needs_mask=True,
1535+
cython_dtype=np.int64,
1536+
result_is_index=True,
14881537
direction=direction, limit=limit)
14891538

14901539
@Substitution(name='groupby')
@@ -1873,33 +1922,81 @@ def cummax(self, axis=0, **kwargs):
18731922

18741923
return self._cython_transform('cummax', numeric_only=False)
18751924

1876-
def _get_cythonized_result(self, how, grouper, needs_mask=False,
1877-
needs_ngroups=False, **kwargs):
1925+
def _get_cythonized_result(self, how, grouper, aggregate=False,
1926+
cython_dtype=None, needs_values=False,
1927+
needs_mask=False, needs_ngroups=False,
1928+
result_is_index=False,
1929+
pre_processing=None, post_processing=None,
1930+
**kwargs):
18781931
"""Get result for Cythonized functions
18791932
18801933
Parameters
18811934
----------
18821935
how : str, Cythonized function name to be called
18831936
grouper : Grouper object containing pertinent group info
1937+
aggregate : bool, default False
1938+
Whether the result should be aggregated to match the number of
1939+
groups
1940+
cython_dtype : default None
1941+
Type of the array that will be modified by the Cython call. If
1942+
`None`, the type will be inferred from the values of each slice
1943+
needs_values : bool, default False
1944+
Whether the values should be a part of the Cython call
1945+
signature
18841946
needs_mask : bool, default False
1885-
Whether boolean mask needs to be part of the Cython call signature
1947+
Whether boolean mask needs to be part of the Cython call
1948+
signature
18861949
needs_ngroups : bool, default False
1887-
Whether number of groups part of the Cython call signature
1950+
Whether number of groups is part of the Cython call signature
1951+
result_is_index : bool, default False
1952+
Whether the result of the Cython operation is an index of
1953+
values to be retrieved, instead of the actual values themselves
1954+
pre_processing : function, default None
1955+
Function to be applied to `values` prior to passing to Cython.
1956+
Raises ValueError if supplied while `needs_values` is False.
1957+
post_processing : function, default None
1958+
Function to be applied to result of Cython function
18881959
**kwargs : dict
18891960
Extra arguments to be passed back to Cython funcs
18901961
18911962
Returns
18921963
-------
18931964
`Series` or `DataFrame` holding the result of the Cython operation
18941965
"""
1966+
if result_is_index and aggregate:
1967+
raise ValueError("'result_is_index' and 'aggregate' cannot both "
1968+
"be True!")
1969+
# Reject a non-callable `post_processing` up front so the caller gets a
# clear ValueError instead of a TypeError after the Cython call has run.
if post_processing:
    if not callable(post_processing):
        raise ValueError("'post_processing' must be a callable!")
1972+
if pre_processing:
1973+
if not callable(pre_processing):
1974+
raise ValueError("'pre_processing' must be a callable!")
1975+
if not needs_values:
1976+
raise ValueError("Cannot use 'pre_processing' without "
1977+
"specifying 'needs_values'!")
18951978

18961979
labels, _, ngroups = grouper.group_info
18971980
output = collections.OrderedDict()
18981981
base_func = getattr(libgroupby, how)
18991982

19001983
for name, obj in self._iterate_slices():
1901-
indexer = np.zeros_like(labels, dtype=np.int64)
1902-
func = partial(base_func, indexer, labels)
1984+
if aggregate:
1985+
result_sz = ngroups
1986+
else:
1987+
result_sz = len(obj.values)
1988+
1989+
# Resolve the output dtype per slice without rebinding `cython_dtype`,
# so a dtype inferred from one column is never carried over to the next.
# `None` is the documented "infer from the values" sentinel.
slice_dtype = obj.values.dtype if cython_dtype is None else cython_dtype
result = np.zeros(result_sz, dtype=slice_dtype)
1993+
func = partial(base_func, result, labels)
1994+
if needs_values:
1995+
vals = obj.values
1996+
if pre_processing:
1997+
vals = pre_processing(vals)
1998+
func = partial(func, vals)
1999+
19032000
if needs_mask:
19042001
mask = isnull(obj.values).view(np.uint8)
19052002
func = partial(func, mask)
@@ -1908,9 +2005,19 @@ def _get_cythonized_result(self, how, grouper, needs_mask=False,
19082005
func = partial(func, ngroups)
19092006

19102007
func(**kwargs)  # Call func to modify `result` in place
1911-
output[name] = algorithms.take_nd(obj.values, indexer)
19122008

1913-
return self._wrap_transformed_output(output)
2009+
if result_is_index:
2010+
result = algorithms.take_nd(obj.values, result)
2011+
2012+
if post_processing:
2013+
result = post_processing(result)
2014+
2015+
output[name] = result
2016+
2017+
if aggregate:
2018+
return self._wrap_aggregated_output(output)
2019+
else:
2020+
return self._wrap_transformed_output(output)
19142021

19152022
@Substitution(name='groupby')
19162023
@Appender(_doc_template)
@@ -1930,7 +2037,9 @@ def shift(self, periods=1, freq=None, axis=0):
19302037
return self.apply(lambda x: x.shift(periods, freq, axis))
19312038

19322039
return self._get_cythonized_result('group_shift_indexer',
1933-
self.grouper, needs_ngroups=True,
2040+
self.grouper, cython_dtype=np.int64,
2041+
needs_ngroups=True,
2042+
result_is_index=True,
19342043
periods=periods)
19352044

19362045
@Substitution(name='groupby')

pandas/tests/groupby/test_groupby.py

+25
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pandas import (date_range, bdate_range, Timestamp,
1010
Index, MultiIndex, DataFrame, Series,
1111
concat, Panel, DatetimeIndex, read_csv)
12+
from pandas.core.dtypes.missing import isna
1213
from pandas.errors import UnsupportedFunctionCall, PerformanceWarning
1314
from pandas.util.testing import (assert_frame_equal, assert_index_equal,
1415
assert_series_equal, assert_almost_equal)
@@ -2116,6 +2117,30 @@ def interweave(list_obj):
21162117
exp = DataFrame({'key': keys, 'val': _exp_vals})
21172118
assert_frame_equal(result, exp)
21182119

2120+
@pytest.mark.parametrize("agg_func", ['any', 'all'])
2121+
@pytest.mark.parametrize("skipna", [True, False])
2122+
@pytest.mark.parametrize("vals", [
2123+
['foo', 'bar', 'baz'], ['foo', '', ''], ['', '', ''],
2124+
[1, 2, 3], [1, 0, 0], [0, 0, 0],
2125+
[1., 2., 3.], [1., 0., 0.], [0., 0., 0.],
2126+
[True, True, True], [True, False, False], [False, False, False],
2127+
[np.nan, np.nan, np.nan]
2128+
])
2129+
def test_groupby_bool_aggs(self, agg_func, skipna, vals):
2130+
df = DataFrame({'key': ['a'] * 3 + ['b'] * 3, 'val': vals * 2})
2131+
2132+
# Figure out expectation using Python builtin
2133+
exp = getattr(compat.builtins, agg_func)(vals)
2134+
2135+
# edge case for missing data with skipna and 'any'
2136+
if skipna and all(isna(vals)) and agg_func == 'any':
2137+
exp = False
2138+
2139+
exp_df = DataFrame([exp] * 2, columns=['val'], index=pd.Index(
2140+
['a', 'b'], name='key'))
2141+
result = getattr(df.groupby('key'), agg_func)(skipna=skipna)
2142+
assert_frame_equal(result, exp_df)
2143+
21192144
def test_dont_clobber_name_column(self):
21202145
df = DataFrame({'key': ['a', 'a', 'a', 'b', 'b', 'b'],
21212146
'name': ['foo', 'bar', 'baz'] * 2})

0 commit comments

Comments
 (0)