3
3
# don't introduce a pandas/pandas.compat import
4
4
# or we get a bootstrapping problem
5
5
from StringIO import StringIO
6
- import os
7
6
8
7
header = """
9
8
cimport numpy as np
34
33
ctypedef unsigned char UChar
35
34
36
35
cimport util
37
- from util cimport is_array, _checknull, _checknan
36
+ from util cimport is_array, _checknull, _checknan, get_nat
37
+
38
+ cdef int64_t iNaT = get_nat()
38
39
39
40
# import datetime C API
40
41
PyDateTime_IMPORT
@@ -1150,6 +1151,79 @@ def group_var_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
1150
1151
(ct * ct - ct))
1151
1152
"""
1152
1153
1154
+ group_count_template = """@cython.boundscheck(False)
1155
+ @cython.wraparound(False)
1156
+ def group_count_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
1157
+ ndarray[int64_t] counts,
1158
+ ndarray[%(c_type)s, ndim=2] values,
1159
+ ndarray[int64_t] labels):
1160
+ '''
1161
+ Only aggregates on axis=0
1162
+ '''
1163
+ cdef:
1164
+ Py_ssize_t i, j, lab
1165
+ Py_ssize_t N = values.shape[0], K = values.shape[1]
1166
+ %(c_type)s val
1167
+ ndarray[int64_t, ndim=2] nobs = np.zeros((out.shape[0], out.shape[1]),
1168
+ dtype=np.int64)
1169
+
1170
+ if len(values) != len(labels):
1171
+ raise AssertionError("len(index) != len(labels)")
1172
+
1173
+ for i in range(N):
1174
+ lab = labels[i]
1175
+ if lab < 0:
1176
+ continue
1177
+
1178
+ counts[lab] += 1
1179
+ for j in range(K):
1180
+ val = values[i, j]
1181
+
1182
+ # not nan
1183
+ nobs[lab, j] += val == val and val != iNaT
1184
+
1185
+ for i in range(len(counts)):
1186
+ for j in range(K):
1187
+ out[i, j] = nobs[i, j]
1188
+
1189
+
1190
+ """
1191
+
1192
+ group_count_bin_template = """@cython.boundscheck(False)
1193
+ @cython.wraparound(False)
1194
+ def group_count_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
1195
+ ndarray[int64_t] counts,
1196
+ ndarray[%(c_type)s, ndim=2] values,
1197
+ ndarray[int64_t] bins):
1198
+ '''
1199
+ Only aggregates on axis=0
1200
+ '''
1201
+ cdef:
1202
+ Py_ssize_t i, j, ngroups
1203
+ Py_ssize_t N = values.shape[0], K = values.shape[1], b = 0
1204
+ %(c_type)s val
1205
+ ndarray[int64_t, ndim=2] nobs = np.zeros((out.shape[0], out.shape[1]),
1206
+ dtype=np.int64)
1207
+
1208
+ ngroups = len(bins) + (bins[len(bins) - 1] != N)
1209
+
1210
+ for i in range(N):
1211
+ while b < ngroups - 1 and i >= bins[b]:
1212
+ b += 1
1213
+
1214
+ counts[b] += 1
1215
+ for j in range(K):
1216
+ val = values[i, j]
1217
+
1218
+ # not nan
1219
+ nobs[b, j] += val == val and val != iNaT
1220
+
1221
+ for i in range(ngroups):
1222
+ for j in range(K):
1223
+ out[i, j] = nobs[i, j]
1224
+
1225
+
1226
+ """
1153
1227
# add passing bin edges, instead of labels
1154
1228
1155
1229
@@ -2145,7 +2219,8 @@ def put2d_%(name)s_%(dest_type)s(ndarray[%(c_type)s, ndim=2, cast=True] values,
2145
2219
#-------------------------------------------------------------------------
2146
2220
# Generators
2147
2221
2148
- def generate_put_template (template , use_ints = True , use_floats = True ):
2222
+ def generate_put_template (template , use_ints = True , use_floats = True ,
2223
+ use_objects = False ):
2149
2224
floats_list = [
2150
2225
('float64' , 'float64_t' , 'float64_t' , 'np.float64' ),
2151
2226
('float32' , 'float32_t' , 'float32_t' , 'np.float32' ),
@@ -2156,11 +2231,14 @@ def generate_put_template(template, use_ints = True, use_floats = True):
2156
2231
('int32' , 'int32_t' , 'float64_t' , 'np.float64' ),
2157
2232
('int64' , 'int64_t' , 'float64_t' , 'np.float64' ),
2158
2233
]
2234
+ object_list = [('object' , 'object' , 'float64_t' , 'np.float64' )]
2159
2235
function_list = []
2160
2236
if use_floats :
2161
2237
function_list .extend (floats_list )
2162
2238
if use_ints :
2163
2239
function_list .extend (ints_list )
2240
+ if use_objects :
2241
+ function_list .extend (object_list )
2164
2242
2165
2243
output = StringIO ()
2166
2244
for name , c_type , dest_type , dest_dtype in function_list :
@@ -2251,6 +2329,8 @@ def generate_from_template(template, exclude=None):
2251
2329
group_max_bin_template ,
2252
2330
group_ohlc_template ]
2253
2331
2332
+ groupby_count = [group_count_template , group_count_bin_template ]
2333
+
2254
2334
templates_1d = [map_indices_template ,
2255
2335
pad_template ,
2256
2336
backfill_template ,
@@ -2272,6 +2352,7 @@ def generate_from_template(template, exclude=None):
2272
2352
take_2d_axis1_template ,
2273
2353
take_2d_multi_template ]
2274
2354
2355
+
2275
2356
def generate_take_cython_file (path = 'generated.pyx' ):
2276
2357
with open (path , 'w' ) as f :
2277
2358
print (header , file = f )
@@ -2288,7 +2369,10 @@ def generate_take_cython_file(path='generated.pyx'):
2288
2369
print (generate_put_template (template ), file = f )
2289
2370
2290
2371
for template in groupbys :
2291
- print (generate_put_template (template , use_ints = False ), file = f )
2372
+ print (generate_put_template (template , use_ints = False ), file = f )
2373
+
2374
+ for template in groupby_count :
2375
+ print (generate_put_template (template , use_objects = True ), file = f )
2292
2376
2293
2377
# for template in templates_1d_datetime:
2294
2378
# print >> f, generate_from_template_datetime(template)
@@ -2299,5 +2383,6 @@ def generate_take_cython_file(path='generated.pyx'):
2299
2383
for template in nobool_1d_templates :
2300
2384
print (generate_from_template (template , exclude = ['bool' ]), file = f )
2301
2385
2386
+
2302
2387
if __name__ == '__main__' :
2303
2388
generate_take_cython_file ()
0 commit comments