Skip to content

ENH: cythonize groupby.count #7016

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 5, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@ Improvements to existing features
in item handling (:issue:`6745`, :issue:`6988`).
- Improve performance in certain reindexing operations by optimizing ``take_2d`` (:issue:`6749`)
- Arrays of strings can be wrapped to a specified width (``str.wrap``) (:issue:`6999`)
- ``GroupBy.count()`` is now implemented in Cython and is much faster for large
numbers of groups (:issue:`7016`).

.. _release.bug_fixes-0.14.0:

Expand Down
2 changes: 2 additions & 0 deletions doc/source/v0.14.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,8 @@ Performance
- Performance improvements in timedelta conversions for integer dtypes (:issue:`6754`)
- Improved performance of compatible pickles (:issue:`6899`)
- Improve performance in certain reindexing operations by optimizing ``take_2d`` (:issue:`6749`)
- ``GroupBy.count()`` is now implemented in Cython and is much faster for large
numbers of groups (:issue:`7016`).

Experimental
~~~~~~~~~~~~
Expand Down
40 changes: 21 additions & 19 deletions pandas/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import collections

from pandas.compat import(
zip, builtins, range, long, lrange, lzip,
zip, builtins, range, long, lzip,
OrderedDict, callable
)
from pandas import compat
Expand Down Expand Up @@ -713,15 +713,6 @@ def size(self):
"""
return self.grouper.size()

def count(self, axis=0):
"""
Number of non-null items in each group.
axis : axis number, default 0
the grouping axis
"""
self._set_selection_from_grouper()
return self._python_agg_general(lambda x: notnull(x).sum(axis=axis)).astype('int64')

sum = _groupby_function('sum', 'add', np.sum)
prod = _groupby_function('prod', 'prod', np.prod)
min = _groupby_function('min', 'min', np.min, numeric_only=False)
Expand All @@ -731,6 +722,12 @@ def count(self, axis=0):
last = _groupby_function('last', 'last', _last_compat, numeric_only=False,
_convert=True)

_count = _groupby_function('_count', 'count',
lambda x, axis=0: notnull(x).sum(axis=axis),
numeric_only=False)

def count(self, axis=0):
    """
    Number of non-null items in each group.

    Parameters
    ----------
    axis : int, default 0
        The grouping axis (kept for backward compatibility with the
        previous pure-Python implementation, which passed it through).

    Returns
    -------
    Per-group count of non-null values, cast to int64.
    """
    # NOTE(review): ``axis`` is accepted but not forwarded to
    # ``self._count()`` (the Cython-backed ``_groupby_function``) --
    # confirm whether axis != 0 is still meant to be supported.
    return self._count().astype('int64')

def ohlc(self):
"""
Expand Down Expand Up @@ -1318,10 +1315,11 @@ def get_group_levels(self):
'f': lambda func, a, b, c, d: func(a, b, c, d, 1)
},
'last': 'group_last',
'count': 'group_count',
}

_cython_transforms = {
'std': np.sqrt
'std': np.sqrt,
}

_cython_arity = {
Expand Down Expand Up @@ -1390,25 +1388,27 @@ def aggregate(self, values, how, axis=0):
values = com.ensure_float(values)
is_numeric = True
else:
if issubclass(values.dtype.type, np.datetime64):
raise Exception('Cython not able to handle this case')

values = values.astype(object)
is_numeric = False
is_numeric = issubclass(values.dtype.type, (np.datetime64,
np.timedelta64))
if is_numeric:
values = values.view('int64')
else:
values = values.astype(object)

# will be filled in Cython function
result = np.empty(out_shape, dtype=values.dtype)
result = np.empty(out_shape,
dtype=np.dtype('f%d' % values.dtype.itemsize))
result.fill(np.nan)
counts = np.zeros(self.ngroups, dtype=np.int64)

result = self._aggregate(result, counts, values, how, is_numeric)

if self._filter_empty_groups:
if result.ndim == 2:
if is_numeric:
try:
result = lib.row_bool_subset(
result, (counts > 0).view(np.uint8))
else:
except ValueError:
result = lib.row_bool_subset_object(
result, (counts > 0).view(np.uint8))
else:
Expand Down Expand Up @@ -1442,6 +1442,7 @@ def _aggregate(self, result, counts, values, how, is_numeric):
chunk = chunk.squeeze()
agg_func(result[:, :, i], counts, chunk, comp_ids)
else:
#import ipdb; ipdb.set_trace() # XXX BREAKPOINT
agg_func(result, counts, values, comp_ids)

return trans_func(result)
Expand Down Expand Up @@ -1651,6 +1652,7 @@ def names(self):
'f': lambda func, a, b, c, d: func(a, b, c, d, 1)
},
'last': 'group_last_bin',
'count': 'group_count_bin',
}

_name_functions = {
Expand Down
93 changes: 89 additions & 4 deletions pandas/src/generate_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# don't introduce a pandas/pandas.compat import
# or we get a bootstrapping problem
from StringIO import StringIO
import os

header = """
cimport numpy as np
Expand Down Expand Up @@ -34,7 +33,9 @@
ctypedef unsigned char UChar

cimport util
from util cimport is_array, _checknull, _checknan
from util cimport is_array, _checknull, _checknan, get_nat

cdef int64_t iNaT = get_nat()

# import datetime C API
PyDateTime_IMPORT
Expand Down Expand Up @@ -1150,6 +1151,79 @@ def group_var_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
(ct * ct - ct))
"""

# Cython source template for ``group_count_<name>``: counts the non-null
# values per (group, column) for label-based groupers.  The placeholders
# %(name)s / %(c_type)s / %(dest_type2)s are substituted by
# generate_put_template().  A value is "null" when it is NaN (val != val)
# or equal to iNaT, the int64 sentinel for datetime/timedelta NaT.
# NOTE: the template body's indentation is significant -- it becomes the
# indentation of the generated .pyx function.
group_count_template = """@cython.boundscheck(False)
@cython.wraparound(False)
def group_count_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
                         ndarray[int64_t] counts,
                         ndarray[%(c_type)s, ndim=2] values,
                         ndarray[int64_t] labels):
    '''
    Only aggregates on axis=0
    '''
    cdef:
        Py_ssize_t i, j, lab
        Py_ssize_t N = values.shape[0], K = values.shape[1]
        %(c_type)s val
        ndarray[int64_t, ndim=2] nobs = np.zeros((out.shape[0], out.shape[1]),
                                                 dtype=np.int64)

    if len(values) != len(labels):
        raise AssertionError("len(index) != len(labels)")

    for i in range(N):
        lab = labels[i]
        if lab < 0:
            continue

        counts[lab] += 1
        for j in range(K):
            val = values[i, j]

            # not nan
            nobs[lab, j] += val == val and val != iNaT

    for i in range(len(counts)):
        for j in range(K):
            out[i, j] = nobs[i, j]


"""

# Cython source template for ``group_count_bin_<name>``: counts the
# non-null values per (group, column) for bin/edge-based groupers.
# Group membership of row i is found by advancing ``b`` through ``bins``;
# ``ngroups`` gains one extra group when the last bin edge does not
# coincide with N.  Placeholders are substituted by generate_put_template().
# NOTE: the template body's indentation is significant -- it becomes the
# indentation of the generated .pyx function.
group_count_bin_template = """@cython.boundscheck(False)
@cython.wraparound(False)
def group_count_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
                             ndarray[int64_t] counts,
                             ndarray[%(c_type)s, ndim=2] values,
                             ndarray[int64_t] bins):
    '''
    Only aggregates on axis=0
    '''
    cdef:
        Py_ssize_t i, j, ngroups
        Py_ssize_t N = values.shape[0], K = values.shape[1], b = 0
        %(c_type)s val
        ndarray[int64_t, ndim=2] nobs = np.zeros((out.shape[0], out.shape[1]),
                                                 dtype=np.int64)

    ngroups = len(bins) + (bins[len(bins) - 1] != N)

    for i in range(N):
        while b < ngroups - 1 and i >= bins[b]:
            b += 1

        counts[b] += 1
        for j in range(K):
            val = values[i, j]

            # not nan
            nobs[b, j] += val == val and val != iNaT

    for i in range(ngroups):
        for j in range(K):
            out[i, j] = nobs[i, j]


"""
# add passing bin edges, instead of labels


Expand Down Expand Up @@ -2145,7 +2219,8 @@ def put2d_%(name)s_%(dest_type)s(ndarray[%(c_type)s, ndim=2, cast=True] values,
#-------------------------------------------------------------------------
# Generators

def generate_put_template(template, use_ints = True, use_floats = True):
def generate_put_template(template, use_ints = True, use_floats = True,
use_objects=False):
floats_list = [
('float64', 'float64_t', 'float64_t', 'np.float64'),
('float32', 'float32_t', 'float32_t', 'np.float32'),
Expand All @@ -2156,11 +2231,14 @@ def generate_put_template(template, use_ints = True, use_floats = True):
('int32', 'int32_t', 'float64_t', 'np.float64'),
('int64', 'int64_t', 'float64_t', 'np.float64'),
]
object_list = [('object', 'object', 'float64_t', 'np.float64')]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should float really be here? (as that should be the use_floats)? no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, float should be here. ``use_(float|int)s`` really only applies to the first two arguments, which refer to the name and the ctype; the result of a groupby operation is always a float (even count, which is astyped to int64 as the last method call before returning to the user).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh..right..np

function_list = []
if use_floats:
function_list.extend(floats_list)
if use_ints:
function_list.extend(ints_list)
if use_objects:
function_list.extend(object_list)

output = StringIO()
for name, c_type, dest_type, dest_dtype in function_list:
Expand Down Expand Up @@ -2251,6 +2329,8 @@ def generate_from_template(template, exclude=None):
group_max_bin_template,
group_ohlc_template]

groupby_count = [group_count_template, group_count_bin_template]

templates_1d = [map_indices_template,
pad_template,
backfill_template,
Expand All @@ -2272,6 +2352,7 @@ def generate_from_template(template, exclude=None):
take_2d_axis1_template,
take_2d_multi_template]


def generate_take_cython_file(path='generated.pyx'):
with open(path, 'w') as f:
print(header, file=f)
Expand All @@ -2288,7 +2369,10 @@ def generate_take_cython_file(path='generated.pyx'):
print(generate_put_template(template), file=f)

for template in groupbys:
print(generate_put_template(template, use_ints = False), file=f)
print(generate_put_template(template, use_ints=False), file=f)

for template in groupby_count:
print(generate_put_template(template, use_objects=True), file=f)

# for template in templates_1d_datetime:
# print >> f, generate_from_template_datetime(template)
Expand All @@ -2299,5 +2383,6 @@ def generate_take_cython_file(path='generated.pyx'):
for template in nobool_1d_templates:
print(generate_from_template(template, exclude=['bool']), file=f)


if __name__ == '__main__':
generate_take_cython_file()
Loading