@@ -29,7 +29,8 @@ from hashtable cimport *
29
29
{{for on_dtype in on_dtypes}}
30
30
31
31
32
- def asof_join_{{on_dtype}}_by_{{by_dtype}}(ndarray[{{on_dtype}}] left_values,
32
+ def asof_join_backward_{{on_dtype}}_by_{{by_dtype}}(
33
+ ndarray[{{on_dtype}}] left_values,
33
34
ndarray[{{on_dtype}}] right_values,
34
35
ndarray[{{by_dtype}}] left_by_values,
35
36
ndarray[{{by_dtype}}] right_by_values,
@@ -41,6 +42,7 @@ def asof_join_{{on_dtype}}_by_{{by_dtype}}(ndarray[{{on_dtype}}] left_values,
41
42
ndarray[int64_t] left_indexer, right_indexer
42
43
bint has_tolerance = 0
43
44
{{on_dtype}} tolerance_
45
+ {{on_dtype}} diff
44
46
{{table_type}} hash_table
45
47
{{by_dtype}} by_value
46
48
@@ -63,7 +65,7 @@ def asof_join_{{on_dtype}}_by_{{by_dtype}}(ndarray[{{on_dtype}}] left_values,
63
65
if right_pos < 0:
64
66
right_pos = 0
65
67
66
- # find last position in right whose value is less than left's value
68
+ # find last position in right whose value is less than left's
67
69
if allow_exact_matches:
68
70
while right_pos < right_size and\
69
71
right_values[right_pos] <= left_values[left_pos]:
@@ -91,6 +93,119 @@ def asof_join_{{on_dtype}}_by_{{by_dtype}}(ndarray[{{on_dtype}}] left_values,
91
93
92
94
return left_indexer, right_indexer
93
95
96
+
97
def asof_join_forward_{{on_dtype}}_by_{{by_dtype}}(
        ndarray[{{on_dtype}}] left_values,
        ndarray[{{on_dtype}}] right_values,
        ndarray[{{by_dtype}}] left_by_values,
        ndarray[{{by_dtype}}] right_by_values,
        bint allow_exact_matches=1,
        tolerance=None):
    """
    Forward as-of join with a grouping ("by") key.

    For every left row, look up a right row carrying the same "by" key
    whose value is greater than or equal to the left value (strictly
    greater when ``allow_exact_matches`` is false).

    Parameters
    ----------
    left_values, right_values : ndarray
        Join keys of the two sides.
        NOTE(review): the backwards scan only makes sense if both arrays
        are sorted ascending -- confirm against the caller.
    left_by_values, right_by_values : ndarray
        Grouping keys, positionally aligned with the value arrays.
    allow_exact_matches : bint, default 1
        Whether a right value equal to the left value may be matched.
    tolerance : optional
        When given, a candidate whose distance
        ``right_value - left_value`` exceeds it is discarded.

    Returns
    -------
    (left_indexer, right_indexer) : tuple of int64 ndarrays
        Both have one entry per left row.  ``left_indexer[i] == i``;
        ``right_indexer[i]`` is a position into the right arrays, or -1
        when no acceptable match exists.
    """

    cdef:
        Py_ssize_t left_pos, right_pos, left_size, right_size, match_pos
        ndarray[int64_t] left_indexer, right_indexer
        bint has_tolerance
        {{on_dtype}} tolerance_
        {{on_dtype}} gap
        {{table_type}} hash_table
        {{by_dtype}} by_value

    # only materialise the typed tolerance when one was supplied
    has_tolerance = tolerance is not None
    if has_tolerance:
        tolerance_ = tolerance

    left_size = len(left_values)
    right_size = len(right_values)

    left_indexer = np.empty(left_size, dtype=np.int64)
    right_indexer = np.empty(left_size, dtype=np.int64)

    hash_table = {{table_type}}(right_size)

    # walk both sides from the end towards the start
    right_pos = right_size - 1
    for left_pos in range(left_size - 1, -1, -1):
        # the previous iteration may have stepped one past the end
        if right_pos == right_size:
            right_pos = right_size - 1

        # consume every right row that still qualifies for this left row,
        # recording its position under its "by" key; rows visited later
        # have lower positions (presumably replacing the earlier entry
        # for the same key -- confirm set_item semantics)
        if allow_exact_matches:
            while (right_pos >= 0 and
                   right_values[right_pos] >= left_values[left_pos]):
                hash_table.set_item(right_by_values[right_pos], right_pos)
                right_pos -= 1
        else:
            while (right_pos >= 0 and
                   right_values[right_pos] > left_values[left_pos]):
                hash_table.set_item(right_by_values[right_pos], right_pos)
                right_pos -= 1
        # step back onto the last row that was consumed
        right_pos += 1

        # look up the recorded position for this row's "by" key
        by_value = left_by_values[left_pos]
        if by_value in hash_table:
            match_pos = hash_table.get_item(by_value)
        else:
            match_pos = -1

        # drop the candidate when it lies further away than the tolerance
        if has_tolerance and match_pos != -1:
            gap = right_values[match_pos] - left_values[left_pos]
            if gap > tolerance_:
                match_pos = -1

        left_indexer[left_pos] = left_pos
        right_indexer[left_pos] = match_pos

    return left_indexer, right_indexer
160
+
161
+
162
def asof_join_nearest_{{on_dtype}}_by_{{by_dtype}}(
        ndarray[{{on_dtype}}] left_values,
        ndarray[{{on_dtype}}] right_values,
        ndarray[{{by_dtype}}] left_by_values,
        ndarray[{{by_dtype}}] right_by_values,
        bint allow_exact_matches=1,
        tolerance=None):
    """
    Nearest as-of join with a grouping ("by") key.

    Runs the backward and the forward grouped joins with the same
    arguments and, row by row, keeps whichever right position is closer
    to the left value.

    Parameters
    ----------
    left_values, right_values : ndarray
        Join keys of the two sides.
    left_by_values, right_by_values : ndarray
        Grouping keys, positionally aligned with the value arrays.
    allow_exact_matches : bint, default 1
        Forwarded unchanged to both directional joins.
    tolerance : optional
        Forwarded unchanged to both directional joins.

    Returns
    -------
    (left_indexer, right_indexer) : tuple of int64 ndarrays
        ``left_indexer`` is copied from the backward join's left indexer.
        ``right_indexer[i]`` is the backward match when it is at least as
        close as the forward match (ties go backward), the only available
        match when just one direction found one, and -1 when neither did.
    """

    cdef:
        Py_ssize_t left_size, i
        ndarray[int64_t] left_indexer, right_indexer, bli, bri, fli, fri
        {{on_dtype}} bdiff, fdiff

    left_size = len(left_values)

    left_indexer = np.empty(left_size, dtype=np.int64)
    right_indexer = np.empty(left_size, dtype=np.int64)

    # search both forward and backward
    bli, bri = asof_join_backward_{{on_dtype}}_by_{{by_dtype}}(
        left_values,
        right_values,
        left_by_values,
        right_by_values,
        allow_exact_matches,
        tolerance)
    fli, fri = asof_join_forward_{{on_dtype}}_by_{{by_dtype}}(
        left_values,
        right_values,
        left_by_values,
        right_by_values,
        allow_exact_matches,
        tolerance)

    for i in range(len(bri)):
        if bri[i] != -1 and fri[i] != -1:
            # both directions matched: keep the smaller distance,
            # preferring the backward match on a tie
            bdiff = left_values[bli[i]] - right_values[bri[i]]
            fdiff = right_values[fri[i]] - left_values[fli[i]]
            right_indexer[i] = bri[i] if bdiff <= fdiff else fri[i]
        else:
            # at most one direction matched; fri[i] is -1 if neither did
            right_indexer[i] = bri[i] if bri[i] != -1 else fri[i]
        left_indexer[i] = bli[i]

    return left_indexer, right_indexer
208
+
94
209
{{endfor}}
95
210
{{endfor}}
96
211
@@ -111,7 +226,8 @@ dtypes = ['uint8_t', 'uint16_t', 'uint32_t', 'uint64_t',
111
226
{{for on_dtype in dtypes}}
112
227
113
228
114
- def asof_join_{{on_dtype}}(ndarray[{{on_dtype}}] left_values,
229
+ def asof_join_backward_{{on_dtype}}(
230
+ ndarray[{{on_dtype}}] left_values,
115
231
ndarray[{{on_dtype}}] right_values,
116
232
bint allow_exact_matches=1,
117
233
tolerance=None):
@@ -121,6 +237,7 @@ def asof_join_{{on_dtype}}(ndarray[{{on_dtype}}] left_values,
121
237
ndarray[int64_t] left_indexer, right_indexer
122
238
bint has_tolerance = 0
123
239
{{on_dtype}} tolerance_
240
+ {{on_dtype}} diff
124
241
125
242
# if we are using tolerance, set our objects
126
243
if tolerance is not None:
@@ -139,7 +256,7 @@ def asof_join_{{on_dtype}}(ndarray[{{on_dtype}}] left_values,
139
256
if right_pos < 0:
140
257
right_pos = 0
141
258
142
- # find last position in right whose value is less than left's value
259
+ # find last position in right whose value is less than left's
143
260
if allow_exact_matches:
144
261
while right_pos < right_size and\
145
262
right_values[right_pos] <= left_values[left_pos]:
@@ -162,5 +279,96 @@ def asof_join_{{on_dtype}}(ndarray[{{on_dtype}}] left_values,
162
279
163
280
return left_indexer, right_indexer
164
281
282
+
283
def asof_join_forward_{{on_dtype}}(
        ndarray[{{on_dtype}}] left_values,
        ndarray[{{on_dtype}}] right_values,
        bint allow_exact_matches=1,
        tolerance=None):
    """
    Forward as-of join without a grouping key.

    For every left row, find the lowest right position whose value is
    greater than or equal to the left value (strictly greater when
    ``allow_exact_matches`` is false).

    Parameters
    ----------
    left_values, right_values : ndarray
        Join keys of the two sides.
        NOTE(review): the backwards scan only makes sense if both arrays
        are sorted ascending -- confirm against the caller.
    allow_exact_matches : bint, default 1
        Whether a right value equal to the left value may be matched.
    tolerance : optional
        When given, a candidate whose distance
        ``right_value - left_value`` exceeds it is discarded.

    Returns
    -------
    (left_indexer, right_indexer) : tuple of int64 ndarrays
        Both have one entry per left row.  ``left_indexer[i] == i``;
        ``right_indexer[i]`` is a position into ``right_values``, or -1
        when no acceptable match exists.
    """

    cdef:
        Py_ssize_t left_pos, right_pos, left_size, right_size
        ndarray[int64_t] left_indexer, right_indexer
        bint has_tolerance
        {{on_dtype}} tolerance_
        {{on_dtype}} gap

    # only materialise the typed tolerance when one was supplied
    has_tolerance = tolerance is not None
    if has_tolerance:
        tolerance_ = tolerance

    left_size = len(left_values)
    right_size = len(right_values)

    left_indexer = np.empty(left_size, dtype=np.int64)
    right_indexer = np.empty(left_size, dtype=np.int64)

    # walk both sides from the end towards the start
    right_pos = right_size - 1
    for left_pos in range(left_size - 1, -1, -1):
        # the previous iteration may have stepped one past the end
        if right_pos == right_size:
            right_pos = right_size - 1

        # move left past every right row that still qualifies
        if allow_exact_matches:
            while (right_pos >= 0 and
                   right_values[right_pos] >= left_values[left_pos]):
                right_pos -= 1
        else:
            while (right_pos >= 0 and
                   right_values[right_pos] > left_values[left_pos]):
                right_pos -= 1
        # step back onto the last qualifying row (or one past the end
        # when nothing qualified)
        right_pos += 1

        left_indexer[left_pos] = left_pos

        # nothing on the right is at or beyond this left value
        if right_pos == right_size:
            right_indexer[left_pos] = -1
            continue

        right_indexer[left_pos] = right_pos

        # drop the match when it lies further away than the tolerance
        if has_tolerance:
            gap = right_values[right_pos] - left_values[left_pos]
            if gap > tolerance_:
                right_indexer[left_pos] = -1

    return left_indexer, right_indexer
336
+
337
+
338
def asof_join_nearest_{{on_dtype}}(
        ndarray[{{on_dtype}}] left_values,
        ndarray[{{on_dtype}}] right_values,
        bint allow_exact_matches=1,
        tolerance=None):
    """
    Nearest as-of join without a grouping key.

    Runs the backward and the forward joins with the same arguments and,
    row by row, keeps whichever right position is closer to the left
    value.

    Parameters
    ----------
    left_values, right_values : ndarray
        Join keys of the two sides.
    allow_exact_matches : bint, default 1
        Forwarded unchanged to both directional joins.
    tolerance : optional
        Forwarded unchanged to both directional joins.

    Returns
    -------
    (left_indexer, right_indexer) : tuple of int64 ndarrays
        ``left_indexer`` is copied from the backward join's left indexer.
        ``right_indexer[i]`` is the backward match when it is at least as
        close as the forward match (ties go backward), the only available
        match when just one direction found one, and -1 when neither did.
    """

    cdef:
        Py_ssize_t left_size, i
        ndarray[int64_t] left_indexer, right_indexer, bli, bri, fli, fri
        {{on_dtype}} bdiff, fdiff

    left_size = len(left_values)

    left_indexer = np.empty(left_size, dtype=np.int64)
    right_indexer = np.empty(left_size, dtype=np.int64)

    # search both forward and backward
    bli, bri = asof_join_backward_{{on_dtype}}(left_values, right_values,
                                               allow_exact_matches, tolerance)
    fli, fri = asof_join_forward_{{on_dtype}}(left_values, right_values,
                                              allow_exact_matches, tolerance)

    for i in range(len(bri)):
        if bri[i] != -1 and fri[i] != -1:
            # both directions matched: keep the smaller distance,
            # preferring the backward match on a tie
            bdiff = left_values[bli[i]] - right_values[bri[i]]
            fdiff = right_values[fri[i]] - left_values[fli[i]]
            right_indexer[i] = bri[i] if bdiff <= fdiff else fri[i]
        else:
            # at most one direction matched; fri[i] is -1 if neither did
            right_indexer[i] = bri[i] if bri[i] != -1 else fri[i]
        left_indexer[i] = bli[i]

    return left_indexer, right_indexer
372
+
165
373
{{endfor}}
166
374
0 commit comments