Skip to content

Commit eb3b677

Browse files
committed
Merge pull request #7016 from cpcloud/groupby-count-cython
ENH: cythonize groupby.count
2 parents 16e70c8 + e82a65a commit eb3b677

File tree

8 files changed

+693
-38
lines changed

8 files changed

+693
-38
lines changed

doc/source/release.rst

+2
Original file line number | Diff line number | Diff line change
@@ -313,6 +313,8 @@ Improvements to existing features
313313
in item handling (:issue:`6745`, :issue:`6988`).
314314
- Improve performance in certain reindexing operations by optimizing ``take_2d`` (:issue:`6749`)
315315
- Arrays of strings can be wrapped to a specified width (``str.wrap``) (:issue:`6999`)
316+
- ``GroupBy.count()`` is now implemented in Cython and is much faster for large
317+
numbers of groups (:issue:`7016`).
316318

317319
.. _release.bug_fixes-0.14.0:
318320

doc/source/v0.14.0.txt

+2
Original file line number | Diff line number | Diff line change
@@ -568,6 +568,8 @@ Performance
568568
- Performance improvements in timedelta conversions for integer dtypes (:issue:`6754`)
569569
- Improved performance of compatible pickles (:issue:`6899`)
570570
- Improve performance in certain reindexing operations by optimizing ``take_2d`` (:issue:`6749`)
571+
- ``GroupBy.count()`` is now implemented in Cython and is much faster for large
572+
numbers of groups (:issue:`7016`).
571573

572574
Experimental
573575
~~~~~~~~~~~~

pandas/core/groupby.py

+21-19
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import collections
66

77
from pandas.compat import(
8-
zip, builtins, range, long, lrange, lzip,
8+
zip, builtins, range, long, lzip,
99
OrderedDict, callable
1010
)
1111
from pandas import compat
@@ -713,15 +713,6 @@ def size(self):
713713
"""
714714
return self.grouper.size()
715715

716-
def count(self, axis=0):
717-
"""
718-
Number of non-null items in each group.
719-
axis : axis number, default 0
720-
the grouping axis
721-
"""
722-
self._set_selection_from_grouper()
723-
return self._python_agg_general(lambda x: notnull(x).sum(axis=axis)).astype('int64')
724-
725716
sum = _groupby_function('sum', 'add', np.sum)
726717
prod = _groupby_function('prod', 'prod', np.prod)
727718
min = _groupby_function('min', 'min', np.min, numeric_only=False)
@@ -731,6 +722,12 @@ def count(self, axis=0):
731722
last = _groupby_function('last', 'last', _last_compat, numeric_only=False,
732723
_convert=True)
733724

725+
_count = _groupby_function('_count', 'count',
726+
lambda x, axis=0: notnull(x).sum(axis=axis),
727+
numeric_only=False)
728+
729+
def count(self, axis=0):
730+
return self._count().astype('int64')
734731

735732
def ohlc(self):
736733
"""
@@ -1318,10 +1315,11 @@ def get_group_levels(self):
13181315
'f': lambda func, a, b, c, d: func(a, b, c, d, 1)
13191316
},
13201317
'last': 'group_last',
1318+
'count': 'group_count',
13211319
}
13221320

13231321
_cython_transforms = {
1324-
'std': np.sqrt
1322+
'std': np.sqrt,
13251323
}
13261324

13271325
_cython_arity = {
@@ -1390,25 +1388,27 @@ def aggregate(self, values, how, axis=0):
13901388
values = com.ensure_float(values)
13911389
is_numeric = True
13921390
else:
1393-
if issubclass(values.dtype.type, np.datetime64):
1394-
raise Exception('Cython not able to handle this case')
1395-
1396-
values = values.astype(object)
1397-
is_numeric = False
1391+
is_numeric = issubclass(values.dtype.type, (np.datetime64,
1392+
np.timedelta64))
1393+
if is_numeric:
1394+
values = values.view('int64')
1395+
else:
1396+
values = values.astype(object)
13981397

13991398
# will be filled in Cython function
1400-
result = np.empty(out_shape, dtype=values.dtype)
1399+
result = np.empty(out_shape,
1400+
dtype=np.dtype('f%d' % values.dtype.itemsize))
14011401
result.fill(np.nan)
14021402
counts = np.zeros(self.ngroups, dtype=np.int64)
14031403

14041404
result = self._aggregate(result, counts, values, how, is_numeric)
14051405

14061406
if self._filter_empty_groups:
14071407
if result.ndim == 2:
1408-
if is_numeric:
1408+
try:
14091409
result = lib.row_bool_subset(
14101410
result, (counts > 0).view(np.uint8))
1411-
else:
1411+
except ValueError:
14121412
result = lib.row_bool_subset_object(
14131413
result, (counts > 0).view(np.uint8))
14141414
else:
@@ -1442,6 +1442,7 @@ def _aggregate(self, result, counts, values, how, is_numeric):
14421442
chunk = chunk.squeeze()
14431443
agg_func(result[:, :, i], counts, chunk, comp_ids)
14441444
else:
1445+
#import ipdb; ipdb.set_trace() # XXX BREAKPOINT
14451446
agg_func(result, counts, values, comp_ids)
14461447

14471448
return trans_func(result)
@@ -1651,6 +1652,7 @@ def names(self):
16511652
'f': lambda func, a, b, c, d: func(a, b, c, d, 1)
16521653
},
16531654
'last': 'group_last_bin',
1655+
'count': 'group_count_bin',
16541656
}
16551657

16561658
_name_functions = {

pandas/src/generate_code.py

+89-4
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
# don't introduce a pandas/pandas.compat import
44
# or we get a bootstrapping problem
55
from StringIO import StringIO
6-
import os
76

87
header = """
98
cimport numpy as np
@@ -34,7 +33,9 @@
3433
ctypedef unsigned char UChar
3534
3635
cimport util
37-
from util cimport is_array, _checknull, _checknan
36+
from util cimport is_array, _checknull, _checknan, get_nat
37+
38+
cdef int64_t iNaT = get_nat()
3839
3940
# import datetime C API
4041
PyDateTime_IMPORT
@@ -1150,6 +1151,79 @@ def group_var_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
11501151
(ct * ct - ct))
11511152
"""
11521153

1154+
group_count_template = """@cython.boundscheck(False)
1155+
@cython.wraparound(False)
1156+
def group_count_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
1157+
ndarray[int64_t] counts,
1158+
ndarray[%(c_type)s, ndim=2] values,
1159+
ndarray[int64_t] labels):
1160+
'''
1161+
Only aggregates on axis=0
1162+
'''
1163+
cdef:
1164+
Py_ssize_t i, j, lab
1165+
Py_ssize_t N = values.shape[0], K = values.shape[1]
1166+
%(c_type)s val
1167+
ndarray[int64_t, ndim=2] nobs = np.zeros((out.shape[0], out.shape[1]),
1168+
dtype=np.int64)
1169+
1170+
if len(values) != len(labels):
1171+
raise AssertionError("len(index) != len(labels)")
1172+
1173+
for i in range(N):
1174+
lab = labels[i]
1175+
if lab < 0:
1176+
continue
1177+
1178+
counts[lab] += 1
1179+
for j in range(K):
1180+
val = values[i, j]
1181+
1182+
# not nan
1183+
nobs[lab, j] += val == val and val != iNaT
1184+
1185+
for i in range(len(counts)):
1186+
for j in range(K):
1187+
out[i, j] = nobs[i, j]
1188+
1189+
1190+
"""
1191+
1192+
group_count_bin_template = """@cython.boundscheck(False)
1193+
@cython.wraparound(False)
1194+
def group_count_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
1195+
ndarray[int64_t] counts,
1196+
ndarray[%(c_type)s, ndim=2] values,
1197+
ndarray[int64_t] bins):
1198+
'''
1199+
Only aggregates on axis=0
1200+
'''
1201+
cdef:
1202+
Py_ssize_t i, j, ngroups
1203+
Py_ssize_t N = values.shape[0], K = values.shape[1], b = 0
1204+
%(c_type)s val
1205+
ndarray[int64_t, ndim=2] nobs = np.zeros((out.shape[0], out.shape[1]),
1206+
dtype=np.int64)
1207+
1208+
ngroups = len(bins) + (bins[len(bins) - 1] != N)
1209+
1210+
for i in range(N):
1211+
while b < ngroups - 1 and i >= bins[b]:
1212+
b += 1
1213+
1214+
counts[b] += 1
1215+
for j in range(K):
1216+
val = values[i, j]
1217+
1218+
# not nan
1219+
nobs[b, j] += val == val and val != iNaT
1220+
1221+
for i in range(ngroups):
1222+
for j in range(K):
1223+
out[i, j] = nobs[i, j]
1224+
1225+
1226+
"""
11531227
# add passing bin edges, instead of labels
11541228

11551229

@@ -2145,7 +2219,8 @@ def put2d_%(name)s_%(dest_type)s(ndarray[%(c_type)s, ndim=2, cast=True] values,
21452219
#-------------------------------------------------------------------------
21462220
# Generators
21472221

2148-
def generate_put_template(template, use_ints = True, use_floats = True):
2222+
def generate_put_template(template, use_ints = True, use_floats = True,
2223+
use_objects=False):
21492224
floats_list = [
21502225
('float64', 'float64_t', 'float64_t', 'np.float64'),
21512226
('float32', 'float32_t', 'float32_t', 'np.float32'),
@@ -2156,11 +2231,14 @@ def generate_put_template(template, use_ints = True, use_floats = True):
21562231
('int32', 'int32_t', 'float64_t', 'np.float64'),
21572232
('int64', 'int64_t', 'float64_t', 'np.float64'),
21582233
]
2234+
object_list = [('object', 'object', 'float64_t', 'np.float64')]
21592235
function_list = []
21602236
if use_floats:
21612237
function_list.extend(floats_list)
21622238
if use_ints:
21632239
function_list.extend(ints_list)
2240+
if use_objects:
2241+
function_list.extend(object_list)
21642242

21652243
output = StringIO()
21662244
for name, c_type, dest_type, dest_dtype in function_list:
@@ -2251,6 +2329,8 @@ def generate_from_template(template, exclude=None):
22512329
group_max_bin_template,
22522330
group_ohlc_template]
22532331

2332+
groupby_count = [group_count_template, group_count_bin_template]
2333+
22542334
templates_1d = [map_indices_template,
22552335
pad_template,
22562336
backfill_template,
@@ -2272,6 +2352,7 @@ def generate_from_template(template, exclude=None):
22722352
take_2d_axis1_template,
22732353
take_2d_multi_template]
22742354

2355+
22752356
def generate_take_cython_file(path='generated.pyx'):
22762357
with open(path, 'w') as f:
22772358
print(header, file=f)
@@ -2288,7 +2369,10 @@ def generate_take_cython_file(path='generated.pyx'):
22882369
print(generate_put_template(template), file=f)
22892370

22902371
for template in groupbys:
2291-
print(generate_put_template(template, use_ints = False), file=f)
2372+
print(generate_put_template(template, use_ints=False), file=f)
2373+
2374+
for template in groupby_count:
2375+
print(generate_put_template(template, use_objects=True), file=f)
22922376

22932377
# for template in templates_1d_datetime:
22942378
# print >> f, generate_from_template_datetime(template)
@@ -2299,5 +2383,6 @@ def generate_take_cython_file(path='generated.pyx'):
22992383
for template in nobool_1d_templates:
23002384
print(generate_from_template(template, exclude=['bool']), file=f)
23012385

2386+
23022387
if __name__ == '__main__':
23032388
generate_take_cython_file()

0 commit comments

Comments
 (0)