Skip to content

Commit d523d9f

Browse files
jbrockmendel authored and jreback committed
Use fused types for _take_2d (pandas-dev#22917)
1 parent ee27fab commit d523d9f

5 files changed

+81 -70 lines changed

pandas/_libs/algos_common_helper.pxi.in

-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Template for each `dtype` helper function using 1-d template
33

44
# 1-d template
5-
- map_indices
65
- pad
76
- pad_1d
87
- pad_2d

pandas/_libs/algos_rank_helper.pxi.in

-9
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,8 @@ dtypes = [('object', 'object', 'Infinity()', 'NegInfinity()'),
2424

2525
@cython.wraparound(False)
2626
@cython.boundscheck(False)
27-
{{if dtype == 'object'}}
28-
29-
3027
def rank_1d_{{dtype}}(object in_arr, ties_method='average',
3128
ascending=True, na_option='keep', pct=False):
32-
{{else}}
33-
34-
35-
def rank_1d_{{dtype}}(object in_arr, ties_method='average', ascending=True,
36-
na_option='keep', pct=False):
37-
{{endif}}
3829
"""
3930
Fast NaN-friendly version of scipy.stats.rankdata
4031
"""

pandas/_libs/algos_take_helper.pxi.in

+21-15
Original file line numberDiff line numberDiff line change
@@ -260,33 +260,39 @@ def take_2d_multi_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=2] values,
260260

261261
{{endfor}}
262262

263-
#----------------------------------------------------------------------
263+
# ----------------------------------------------------------------------
264264
# take_2d internal function
265-
#----------------------------------------------------------------------
265+
# ----------------------------------------------------------------------
266266

267-
{{py:
268-
269-
# dtype, ctype, init_result
270-
dtypes = [('float64', 'float64_t', 'np.empty_like(values)'),
271-
('uint64', 'uint64_t', 'np.empty_like(values)'),
272-
('object', 'object', 'values.copy()'),
273-
('int64', 'int64_t', 'np.empty_like(values)')]
274-
}}
267+
ctypedef fused take_t:
268+
float64_t
269+
uint64_t
270+
int64_t
271+
object
275272

276-
{{for dtype, ctype, init_result in dtypes}}
277273

278-
cdef _take_2d_{{dtype}}(ndarray[{{ctype}}, ndim=2] values, object idx):
274+
cdef _take_2d(ndarray[take_t, ndim=2] values, object idx):
279275
cdef:
280276
Py_ssize_t i, j, N, K
281277
ndarray[Py_ssize_t, ndim=2, cast=True] indexer = idx
282-
ndarray[{{ctype}}, ndim=2] result
278+
ndarray[take_t, ndim=2] result
283279
object val
284280

285281
N, K = (<object> values).shape
286-
result = {{init_result}}
282+
283+
if take_t is object:
284+
# evaluated at compile-time
285+
result = values.copy()
286+
else:
287+
result = np.empty_like(values)
288+
287289
for i in range(N):
288290
for j in range(K):
289291
result[i, j] = values[i, indexer[i, j]]
290292
return result
291293

292-
{{endfor}}
294+
295+
_take_2d_object = _take_2d[object]
296+
_take_2d_float64 = _take_2d[float64_t]
297+
_take_2d_int64 = _take_2d[int64_t]
298+
_take_2d_uint64 = _take_2d[uint64_t]

pandas/_libs/join_func_helper.pxi.in

+22-22
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,21 @@ def asof_join_backward_{{on_dtype}}_by_{{by_dtype}}(
6868

6969
# find last position in right whose value is less than left's
7070
if allow_exact_matches:
71-
while right_pos < right_size and\
72-
right_values[right_pos] <= left_values[left_pos]:
71+
while (right_pos < right_size and
72+
right_values[right_pos] <= left_values[left_pos]):
7373
hash_table.set_item(right_by_values[right_pos], right_pos)
7474
right_pos += 1
7575
else:
76-
while right_pos < right_size and\
77-
right_values[right_pos] < left_values[left_pos]:
76+
while (right_pos < right_size and
77+
right_values[right_pos] < left_values[left_pos]):
7878
hash_table.set_item(right_by_values[right_pos], right_pos)
7979
right_pos += 1
8080
right_pos -= 1
8181

8282
# save positions as the desired index
8383
by_value = left_by_values[left_pos]
84-
found_right_pos = hash_table.get_item(by_value)\
85-
if by_value in hash_table else -1
84+
found_right_pos = (hash_table.get_item(by_value)
85+
if by_value in hash_table else -1)
8686
left_indexer[left_pos] = left_pos
8787
right_indexer[left_pos] = found_right_pos
8888

@@ -133,21 +133,21 @@ def asof_join_forward_{{on_dtype}}_by_{{by_dtype}}(
133133

134134
# find first position in right whose value is greater than left's
135135
if allow_exact_matches:
136-
while right_pos >= 0 and\
137-
right_values[right_pos] >= left_values[left_pos]:
136+
while (right_pos >= 0 and
137+
right_values[right_pos] >= left_values[left_pos]):
138138
hash_table.set_item(right_by_values[right_pos], right_pos)
139139
right_pos -= 1
140140
else:
141-
while right_pos >= 0 and\
142-
right_values[right_pos] > left_values[left_pos]:
141+
while (right_pos >= 0 and
142+
right_values[right_pos] > left_values[left_pos]):
143143
hash_table.set_item(right_by_values[right_pos], right_pos)
144144
right_pos -= 1
145145
right_pos += 1
146146

147147
# save positions as the desired index
148148
by_value = left_by_values[left_pos]
149-
found_right_pos = hash_table.get_item(by_value)\
150-
if by_value in hash_table else -1
149+
found_right_pos = (hash_table.get_item(by_value)
150+
if by_value in hash_table else -1)
151151
left_indexer[left_pos] = left_pos
152152
right_indexer[left_pos] = found_right_pos
153153

@@ -259,12 +259,12 @@ def asof_join_backward_{{on_dtype}}(
259259

260260
# find last position in right whose value is less than left's
261261
if allow_exact_matches:
262-
while right_pos < right_size and\
263-
right_values[right_pos] <= left_values[left_pos]:
262+
while (right_pos < right_size and
263+
right_values[right_pos] <= left_values[left_pos]):
264264
right_pos += 1
265265
else:
266-
while right_pos < right_size and\
267-
right_values[right_pos] < left_values[left_pos]:
266+
while (right_pos < right_size and
267+
right_values[right_pos] < left_values[left_pos]):
268268
right_pos += 1
269269
right_pos -= 1
270270

@@ -313,19 +313,19 @@ def asof_join_forward_{{on_dtype}}(
313313

314314
# find first position in right whose value is greater than left's
315315
if allow_exact_matches:
316-
while right_pos >= 0 and\
317-
right_values[right_pos] >= left_values[left_pos]:
316+
while (right_pos >= 0 and
317+
right_values[right_pos] >= left_values[left_pos]):
318318
right_pos -= 1
319319
else:
320-
while right_pos >= 0 and\
321-
right_values[right_pos] > left_values[left_pos]:
320+
while (right_pos >= 0 and
321+
right_values[right_pos] > left_values[left_pos]):
322322
right_pos -= 1
323323
right_pos += 1
324324

325325
# save positions as the desired index
326326
left_indexer[left_pos] = left_pos
327-
right_indexer[left_pos] = right_pos\
328-
if right_pos != right_size else -1
327+
right_indexer[left_pos] = (right_pos
328+
if right_pos != right_size else -1)
329329

330330
# if needed, verify that tolerance is met
331331
if has_tolerance and right_pos != right_size:

pandas/_libs/join_helper.pxi.in

+38-23
Original file line numberDiff line numberDiff line change
@@ -4,42 +4,30 @@ Template for each `dtype` helper function for join
44
WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
55
"""
66

7-
#----------------------------------------------------------------------
7+
# ----------------------------------------------------------------------
88
# left_join_indexer, inner_join_indexer, outer_join_indexer
9-
#----------------------------------------------------------------------
9+
# ----------------------------------------------------------------------
1010

11-
{{py:
12-
13-
# name, c_type, dtype
14-
dtypes = [('float64', 'float64_t', 'np.float64'),
15-
('float32', 'float32_t', 'np.float32'),
16-
('object', 'object', 'object'),
17-
('int32', 'int32_t', 'np.int32'),
18-
('int64', 'int64_t', 'np.int64'),
19-
('uint64', 'uint64_t', 'np.uint64')]
20-
21-
def get_dispatch(dtypes):
22-
23-
for name, c_type, dtype in dtypes:
24-
yield name, c_type, dtype
25-
26-
}}
11+
ctypedef fused join_t:
12+
float64_t
13+
float32_t
14+
object
15+
int32_t
16+
int64_t
17+
uint64_t
2718

28-
{{for name, c_type, dtype in get_dispatch(dtypes)}}
2919

3020
# Joins on ordered, unique indices
3121

3222
# right might contain non-unique values
3323

34-
3524
@cython.wraparound(False)
3625
@cython.boundscheck(False)
37-
def left_join_indexer_unique_{{name}}(ndarray[{{c_type}}] left,
38-
ndarray[{{c_type}}] right):
26+
def left_join_indexer_unique(ndarray[join_t] left, ndarray[join_t] right):
3927
cdef:
4028
Py_ssize_t i, j, nleft, nright
4129
ndarray[int64_t] indexer
42-
{{c_type}} lval, rval
30+
join_t lval, rval
4331

4432
i = 0
4533
j = 0
@@ -78,6 +66,33 @@ def left_join_indexer_unique_{{name}}(ndarray[{{c_type}}] left,
7866
return indexer
7967

8068

69+
left_join_indexer_unique_float64 = left_join_indexer_unique["float64_t"]
70+
left_join_indexer_unique_float32 = left_join_indexer_unique["float32_t"]
71+
left_join_indexer_unique_object = left_join_indexer_unique["object"]
72+
left_join_indexer_unique_int32 = left_join_indexer_unique["int32_t"]
73+
left_join_indexer_unique_int64 = left_join_indexer_unique["int64_t"]
74+
left_join_indexer_unique_uint64 = left_join_indexer_unique["uint64_t"]
75+
76+
77+
{{py:
78+
79+
# name, c_type, dtype
80+
dtypes = [('float64', 'float64_t', 'np.float64'),
81+
('float32', 'float32_t', 'np.float32'),
82+
('object', 'object', 'object'),
83+
('int32', 'int32_t', 'np.int32'),
84+
('int64', 'int64_t', 'np.int64'),
85+
('uint64', 'uint64_t', 'np.uint64')]
86+
87+
def get_dispatch(dtypes):
88+
89+
for name, c_type, dtype in dtypes:
90+
yield name, c_type, dtype
91+
92+
}}
93+
94+
{{for name, c_type, dtype in get_dispatch(dtypes)}}
95+
8196
# @cython.wraparound(False)
8297
# @cython.boundscheck(False)
8398
def left_join_indexer_{{name}}(ndarray[{{c_type}}] left,

0 commit comments

Comments
 (0)