@@ -33,13 +33,15 @@ def asof_join_{{on_dtype}}_by_{{by_dtype}}(ndarray[{{on_dtype}}] left_values,
33
33
ndarray[{{by_dtype}}] left_by_values,
34
34
ndarray[{{by_dtype}}] right_by_values,
35
35
bint allow_exact_matches=1,
36
- tolerance=None):
36
+ tolerance=None,
37
+ int64_t direction_enum=0):
37
38
38
39
cdef:
39
40
Py_ssize_t left_pos, right_pos, left_size, right_size, found_right_pos
40
- ndarray[int64_t] left_indexer, right_indexer
41
+ ndarray[int64_t] left_indexer, right_indexer, bli, bri, fli, fri
41
42
bint has_tolerance = 0
42
43
{{on_dtype}} tolerance_
44
+ {{on_dtype}} diff, bdiff, fdiff
43
45
{{table_type}} hash_table
44
46
{{by_dtype}} by_value
45
47
@@ -56,37 +58,94 @@ def asof_join_{{on_dtype}}_by_{{by_dtype}}(ndarray[{{on_dtype}}] left_values,
56
58
57
59
hash_table = {{table_type}}(right_size)
58
60
59
- right_pos = 0
60
- for left_pos in range(left_size):
61
- # restart right_pos if it went negative in a previous iteration
62
- if right_pos < 0:
63
- right_pos = 0
64
-
65
- # find last position in right whose value is less than left's value
66
- if allow_exact_matches:
67
- while right_pos < right_size and\
68
- right_values[right_pos] <= left_values[left_pos]:
69
- hash_table.set_item(right_by_values[right_pos], right_pos)
70
- right_pos += 1
71
- else:
72
- while right_pos < right_size and\
73
- right_values[right_pos] < left_values[left_pos]:
74
- hash_table.set_item(right_by_values[right_pos], right_pos)
75
- right_pos += 1
76
- right_pos -= 1
77
-
78
- # save positions as the desired index
79
- by_value = left_by_values[left_pos]
80
- found_right_pos = hash_table.get_item(by_value)\
81
- if by_value in hash_table else -1
82
- left_indexer[left_pos] = left_pos
83
- right_indexer[left_pos] = found_right_pos
84
-
85
- # if needed, verify that tolerance is met
86
- if has_tolerance and found_right_pos != -1:
87
- diff = left_values[left_pos] - right_values[found_right_pos]
88
- if diff > tolerance_:
89
- right_indexer[left_pos] = -1
61
+ if direction_enum == 0: #backward
62
+ right_pos = 0
63
+ for left_pos in range(left_size):
64
+ # restart right_pos if it went negative in a previous iteration
65
+ if right_pos < 0:
66
+ right_pos = 0
67
+
68
+ # find last position in right whose value is less than left's
69
+ if allow_exact_matches:
70
+ while right_pos < right_size and\
71
+ right_values[right_pos] <= left_values[left_pos]:
72
+ hash_table.set_item(right_by_values[right_pos], right_pos)
73
+ right_pos += 1
74
+ else:
75
+ while right_pos < right_size and\
76
+ right_values[right_pos] < left_values[left_pos]:
77
+ hash_table.set_item(right_by_values[right_pos], right_pos)
78
+ right_pos += 1
79
+ right_pos -= 1
80
+
81
+ # save positions as the desired index
82
+ by_value = left_by_values[left_pos]
83
+ found_right_pos = hash_table.get_item(by_value)\
84
+ if by_value in hash_table else -1
85
+ left_indexer[left_pos] = left_pos
86
+ right_indexer[left_pos] = found_right_pos
87
+
88
+ # if needed, verify that tolerance is met
89
+ if has_tolerance and found_right_pos != -1:
90
+ diff = left_values[left_pos] - right_values[found_right_pos]
91
+ if diff > tolerance_:
92
+ right_indexer[left_pos] = -1
93
+ elif direction_enum == 1: # forward
94
+ right_pos = right_size - 1
95
+ for left_pos in range(left_size - 1, -1, -1):
96
+ # restart right_pos if it went over in a previous iteration
97
+ if right_pos == right_size:
98
+ right_pos = right_size - 1
99
+
100
+ # find first position in right whose value is greater than left's
101
+ if allow_exact_matches:
102
+ while right_pos >= 0 and\
103
+ right_values[right_pos] >= left_values[left_pos]:
104
+ hash_table.set_item(right_by_values[right_pos], right_pos)
105
+ right_pos -= 1
106
+ else:
107
+ while right_pos >= 0 and\
108
+ right_values[right_pos] > left_values[left_pos]:
109
+ hash_table.set_item(right_by_values[right_pos], right_pos)
110
+ right_pos -= 1
111
+ right_pos += 1
112
+
113
+ # save positions as the desired index
114
+ by_value = left_by_values[left_pos]
115
+ found_right_pos = hash_table.get_item(by_value)\
116
+ if by_value in hash_table else -1
117
+ left_indexer[left_pos] = left_pos
118
+ right_indexer[left_pos] = found_right_pos
119
+
120
+ # if needed, verify that tolerance is met
121
+ if has_tolerance and found_right_pos != -1:
122
+ diff = right_values[found_right_pos] - left_values[left_pos]
123
+ if diff > tolerance_:
124
+ right_indexer[left_pos] = -1
125
+ else: # nearest
126
+ # search both forward and backward
127
+ bli, bri = asof_join_{{on_dtype}}_by_{{by_dtype}}(left_values,
128
+ right_values,
129
+ left_by_values,
130
+ right_by_values,
131
+ allow_exact_matches,
132
+ tolerance, 0)
133
+ fli, fri = asof_join_{{on_dtype}}_by_{{by_dtype}}(left_values,
134
+ right_values,
135
+ left_by_values,
136
+ right_by_values,
137
+ allow_exact_matches,
138
+ tolerance, 1)
139
+
140
+ for i in range(len(bri)):
141
+ # choose timestamp from right with smaller difference
142
+ if bri[i] != -1 and fri[i] != -1:
143
+ bdiff = left_values[bli[i]] - right_values[bri[i]]
144
+ fdiff = right_values[fri[i]] - left_values[fli[i]]
145
+ right_indexer[i] = bri[i] if bdiff <= fdiff else fri[i]
146
+ else:
147
+ right_indexer[i] = bri[i] if bri[i] != -1 else fri[i]
148
+ left_indexer[i] = bli[i]
90
149
91
150
return left_indexer, right_indexer
92
151
@@ -113,13 +172,15 @@ dtypes = ['uint8_t', 'uint16_t', 'uint32_t', 'uint64_t',
113
172
def asof_join_{{on_dtype}}(ndarray[{{on_dtype}}] left_values,
114
173
ndarray[{{on_dtype}}] right_values,
115
174
bint allow_exact_matches=1,
116
- tolerance=None):
175
+ tolerance=None,
176
+ int64_t direction_enum=0):
117
177
118
178
cdef:
119
179
Py_ssize_t left_pos, right_pos, left_size, right_size
120
- ndarray[int64_t] left_indexer, right_indexer
180
+ ndarray[int64_t] left_indexer, right_indexer, bli, bri, fli, fri
121
181
bint has_tolerance = 0
122
182
{{on_dtype}} tolerance_
183
+ {{on_dtype}} diff, bdiff, fdiff
123
184
124
185
# if we are using tolerance, set our objects
125
186
if tolerance is not None:
@@ -132,32 +193,77 @@ def asof_join_{{on_dtype}}(ndarray[{{on_dtype}}] left_values,
132
193
left_indexer = np.empty(left_size, dtype=np.int64)
133
194
right_indexer = np.empty(left_size, dtype=np.int64)
134
195
135
- right_pos = 0
136
- for left_pos in range(left_size):
137
- # restart right_pos if it went negative in a previous iteration
138
- if right_pos < 0:
139
- right_pos = 0
140
-
141
- # find last position in right whose value is less than left's value
142
- if allow_exact_matches:
143
- while right_pos < right_size and\
144
- right_values[right_pos] <= left_values[left_pos]:
145
- right_pos += 1
146
- else:
147
- while right_pos < right_size and\
148
- right_values[right_pos] < left_values[left_pos]:
149
- right_pos += 1
150
- right_pos -= 1
151
-
152
- # save positions as the desired index
153
- left_indexer[left_pos] = left_pos
154
- right_indexer[left_pos] = right_pos
155
-
156
- # if needed, verify that tolerance is met
157
- if has_tolerance and right_pos != -1:
158
- diff = left_values[left_pos] - right_values[right_pos]
159
- if diff > tolerance_:
160
- right_indexer[left_pos] = -1
196
+ if direction_enum == 0: # backward
197
+ right_pos = 0
198
+ for left_pos in range(left_size):
199
+ # restart right_pos if it went negative in a previous iteration
200
+ if right_pos < 0:
201
+ right_pos = 0
202
+
203
+ # find last position in right whose value is less than left's
204
+ if allow_exact_matches:
205
+ while right_pos < right_size and\
206
+ right_values[right_pos] <= left_values[left_pos]:
207
+ right_pos += 1
208
+ else:
209
+ while right_pos < right_size and\
210
+ right_values[right_pos] < left_values[left_pos]:
211
+ right_pos += 1
212
+ right_pos -= 1
213
+
214
+ # save positions as the desired index
215
+ left_indexer[left_pos] = left_pos
216
+ right_indexer[left_pos] = right_pos
217
+
218
+ # if needed, verify that tolerance is met
219
+ if has_tolerance and right_pos != -1:
220
+ diff = left_values[left_pos] - right_values[right_pos]
221
+ if diff > tolerance_:
222
+ right_indexer[left_pos] = -1
223
+ elif direction_enum == 1: # forward
224
+ right_pos = right_size - 1
225
+ for left_pos in range(left_size - 1, -1, -1):
226
+ # restart right_pos if it went over in a previous iteration
227
+ if right_pos == right_size:
228
+ right_pos = right_size - 1
229
+
230
+ # find first position in right whose value is greater than left's
231
+ if allow_exact_matches:
232
+ while right_pos >= 0 and\
233
+ right_values[right_pos] >= left_values[left_pos]:
234
+ right_pos -= 1
235
+ else:
236
+ while right_pos >= 0 and\
237
+ right_values[right_pos] > left_values[left_pos]:
238
+ right_pos -= 1
239
+ right_pos += 1
240
+
241
+ # save positions as the desired index
242
+ left_indexer[left_pos] = left_pos
243
+ right_indexer[left_pos] = right_pos\
244
+ if right_pos != right_size else -1
245
+
246
+ # if needed, verify that tolerance is met
247
+ if has_tolerance and right_pos != right_size:
248
+ diff = right_values[right_pos] - left_values[left_pos]
249
+ if diff > tolerance_:
250
+ right_indexer[left_pos] = -1
251
+ else: # nearest
252
+ # search both forward and backward
253
+ bli, bri = asof_join_{{on_dtype}}(left_values, right_values,
254
+ allow_exact_matches, tolerance, 0)
255
+ fli, fri = asof_join_{{on_dtype}}(left_values, right_values,
256
+ allow_exact_matches, tolerance, 1)
257
+
258
+ for i in range(len(bri)):
259
+ # choose timestamp from right with smaller difference
260
+ if bri[i] != -1 and fri[i] != -1:
261
+ bdiff = left_values[bli[i]] - right_values[bri[i]]
262
+ fdiff = right_values[fri[i]] - left_values[fli[i]]
263
+ right_indexer[i] = bri[i] if bdiff <= fdiff else fri[i]
264
+ else:
265
+ right_indexer[i] = bri[i] if bri[i] != -1 else fri[i]
266
+ left_indexer[i] = bli[i]
161
267
162
268
return left_indexer, right_indexer
163
269
0 commit comments