Skip to content

Commit 01e5c40

Browse files
committed
Use fused types for asof joins
1 parent f6c264c commit 01e5c40

File tree

1 file changed

+72
-40
lines changed

1 file changed

+72
-40
lines changed

pandas/_libs/join_func_helper.pxi.in

+72-40
Original file line numberDiff line numberDiff line change
@@ -210,34 +210,34 @@ def asof_join_nearest_{{on_dtype}}_by_{{by_dtype}}(
210210
{{endfor}}
211211

212212

213-
#----------------------------------------------------------------------
213+
# ----------------------------------------------------------------------
214214
# asof_join
215-
#----------------------------------------------------------------------
216-
217-
{{py:
218-
219-
# on_dtype
220-
dtypes = ['uint8_t', 'uint16_t', 'uint32_t', 'uint64_t',
221-
'int8_t', 'int16_t', 'int32_t', 'int64_t',
222-
'float', 'double']
223-
224-
}}
225-
226-
{{for on_dtype in dtypes}}
227-
228-
229-
def asof_join_backward_{{on_dtype}}(
230-
ndarray[{{on_dtype}}] left_values,
231-
ndarray[{{on_dtype}}] right_values,
232-
bint allow_exact_matches=1,
233-
tolerance=None):
215+
# ----------------------------------------------------------------------
216+
217+
ctypedef fused asof_t:
218+
uint8_t
219+
uint16_t
220+
uint32_t
221+
uint64_t
222+
int8_t
223+
int16_t
224+
int32_t
225+
int64_t
226+
float
227+
double
228+
229+
230+
def asof_join_backward(ndarray[asof_t] left_values,
231+
ndarray[asof_t] right_values,
232+
bint allow_exact_matches=1,
233+
tolerance=None):
234234

235235
cdef:
236236
Py_ssize_t left_pos, right_pos, left_size, right_size
237237
ndarray[int64_t] left_indexer, right_indexer
238238
bint has_tolerance = 0
239-
{{on_dtype}} tolerance_ = 0
240-
{{on_dtype}} diff = 0
239+
asof_t tolerance_ = 0
240+
asof_t diff = 0
241241

242242
# if we are using tolerance, set our objects
243243
if tolerance is not None:
@@ -280,18 +280,29 @@ def asof_join_backward_{{on_dtype}}(
280280
return left_indexer, right_indexer
281281

282282

283-
def asof_join_forward_{{on_dtype}}(
284-
ndarray[{{on_dtype}}] left_values,
285-
ndarray[{{on_dtype}}] right_values,
286-
bint allow_exact_matches=1,
287-
tolerance=None):
283+
asof_join_backward_uint8_t = asof_join_backward["uint8_t"]
284+
asof_join_backward_uint16_t = asof_join_backward["uint16_t"]
285+
asof_join_backward_uint32_t = asof_join_backward["uint32_t"]
286+
asof_join_backward_uint64_t = asof_join_backward["uint64_t"]
287+
asof_join_backward_int8_t = asof_join_backward["int8_t"]
288+
asof_join_backward_int16_t = asof_join_backward["int16_t"]
289+
asof_join_backward_int32_t = asof_join_backward["int32_t"]
290+
asof_join_backward_int64_t = asof_join_backward["int64_t"]
291+
asof_join_backward_float = asof_join_backward["float"]
292+
asof_join_backward_double = asof_join_backward["double"]
293+
294+
295+
def asof_join_forward(ndarray[asof_t] left_values,
296+
ndarray[asof_t] right_values,
297+
bint allow_exact_matches=1,
298+
tolerance=None):
288299

289300
cdef:
290301
Py_ssize_t left_pos, right_pos, left_size, right_size
291302
ndarray[int64_t] left_indexer, right_indexer
292303
bint has_tolerance = 0
293-
{{on_dtype}} tolerance_ = 0
294-
{{on_dtype}} diff = 0
304+
asof_t tolerance_ = 0
305+
asof_t diff = 0
295306

296307
# if we are using tolerance, set our objects
297308
if tolerance is not None:
@@ -335,16 +346,27 @@ def asof_join_forward_{{on_dtype}}(
335346
return left_indexer, right_indexer
336347

337348

338-
def asof_join_nearest_{{on_dtype}}(
339-
ndarray[{{on_dtype}}] left_values,
340-
ndarray[{{on_dtype}}] right_values,
341-
bint allow_exact_matches=1,
342-
tolerance=None):
349+
asof_join_forward_uint8_t = asof_join_forward["uint8_t"]
350+
asof_join_forward_uint16_t = asof_join_forward["uint16_t"]
351+
asof_join_forward_uint32_t = asof_join_forward["uint32_t"]
352+
asof_join_forward_uint64_t = asof_join_forward["uint64_t"]
353+
asof_join_forward_int8_t = asof_join_forward["int8_t"]
354+
asof_join_forward_int16_t = asof_join_forward["int16_t"]
355+
asof_join_forward_int32_t = asof_join_forward["int32_t"]
356+
asof_join_forward_int64_t = asof_join_forward["int64_t"]
357+
asof_join_forward_float = asof_join_forward["float"]
358+
asof_join_forward_double = asof_join_forward["double"]
359+
360+
361+
def asof_join_nearest(ndarray[asof_t] left_values,
362+
ndarray[asof_t] right_values,
363+
bint allow_exact_matches=1,
364+
tolerance=None):
343365

344366
cdef:
345367
Py_ssize_t left_size, right_size, i
346368
ndarray[int64_t] left_indexer, right_indexer, bli, bri, fli, fri
347-
{{on_dtype}} bdiff, fdiff
369+
asof_t bdiff, fdiff
348370

349371
left_size = len(left_values)
350372
right_size = len(right_values)
@@ -353,10 +375,10 @@ def asof_join_nearest_{{on_dtype}}(
353375
right_indexer = np.empty(left_size, dtype=np.int64)
354376

355377
# search both forward and backward
356-
bli, bri = asof_join_backward_{{on_dtype}}(left_values, right_values,
357-
allow_exact_matches, tolerance)
358-
fli, fri = asof_join_forward_{{on_dtype}}(left_values, right_values,
359-
allow_exact_matches, tolerance)
378+
bli, bri = asof_join_backward(left_values, right_values,
379+
allow_exact_matches, tolerance)
380+
fli, fri = asof_join_forward(left_values, right_values,
381+
allow_exact_matches, tolerance)
360382

361383
for i in range(len(bri)):
362384
# choose timestamp from right with smaller difference
@@ -370,4 +392,14 @@ def asof_join_nearest_{{on_dtype}}(
370392

371393
return left_indexer, right_indexer
372394

373-
{{endfor}}
395+
396+
asof_join_nearest_uint8_t = asof_join_nearest["uint8_t"]
397+
asof_join_nearest_uint16_t = asof_join_nearest["uint16_t"]
398+
asof_join_nearest_uint32_t = asof_join_nearest["uint32_t"]
399+
asof_join_nearest_uint64_t = asof_join_nearest["uint64_t"]
400+
asof_join_nearest_int8_t = asof_join_nearest["int8_t"]
401+
asof_join_nearest_int16_t = asof_join_nearest["int16_t"]
402+
asof_join_nearest_int32_t = asof_join_nearest["int32_t"]
403+
asof_join_nearest_int64_t = asof_join_nearest["int64_t"]
404+
asof_join_nearest_float = asof_join_nearest["float"]
405+
asof_join_nearest_double = asof_join_nearest["double"]

0 commit comments

Comments
 (0)