Skip to content

Commit d5629b1

Browse files
jbrockmendel authored and tm9k1 committed
REF: use fused types for the rest of libjoin (pandas-dev#23214)
1 parent 56401a7 commit d5629b1

File tree

4 files changed

+364
-446
lines changed

4 files changed

+364
-446
lines changed

pandas/_libs/join.pyx

+356-4
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ cdef double nan = NaN
1717
from pandas._libs.algos import groupsort_indexer, ensure_platform_int
1818
from pandas.core.algorithms import take_nd
1919

20-
include "join_func_helper.pxi"
21-
2220

2321
def inner_join(ndarray[int64_t] left, ndarray[int64_t] right,
2422
Py_ssize_t max_groups):
@@ -309,8 +307,8 @@ left_join_indexer_unique_int64 = left_join_indexer_unique["int64_t"]
309307
left_join_indexer_unique_uint64 = left_join_indexer_unique["uint64_t"]
310308

311309

312-
# @cython.wraparound(False)
313-
# @cython.boundscheck(False)
310+
@cython.wraparound(False)
311+
@cython.boundscheck(False)
314312
def left_join_indexer(ndarray[join_t] left, ndarray[join_t] right):
315313
"""
316314
Two-pass algorithm for monotonic indexes. Handles many-to-one merges
@@ -656,3 +654,357 @@ outer_join_indexer_object = outer_join_indexer["object"]
656654
outer_join_indexer_int32 = outer_join_indexer["int32_t"]
657655
outer_join_indexer_int64 = outer_join_indexer["int64_t"]
658656
outer_join_indexer_uint64 = outer_join_indexer["uint64_t"]
657+
658+
659+
# ----------------------------------------------------------------------
660+
# asof_join_by
661+
# ----------------------------------------------------------------------
662+
663+
from hashtable cimport (
664+
HashTable, PyObjectHashTable, UInt64HashTable, Int64HashTable)
665+
666+
# Fused type for the "on" values of the asof joins below: every function in
# this section takes ``ndarray[asof_t]`` for both ``left_values`` and
# ``right_values``, and uses the same type for the tolerance and for the
# computed left/right difference.
ctypedef fused asof_t:
    uint8_t
    uint16_t
    uint32_t
    uint64_t
    int8_t
    int16_t
    int32_t
    int64_t
    float
    double
677+
678+
# Fused type for the "by" (grouping key) arrays of the *_on_X_by_Y joins.
# Each member has a matching hash table in those functions:
# object -> PyObjectHashTable, int64_t -> Int64HashTable,
# uint64_t -> UInt64HashTable.
ctypedef fused by_t:
    object
    int64_t
    uint64_t
682+
683+
684+
def asof_join_backward_on_X_by_Y(ndarray[asof_t] left_values,
                                 ndarray[asof_t] right_values,
                                 ndarray[by_t] left_by_values,
                                 ndarray[by_t] right_by_values,
                                 bint allow_exact_matches=1,
                                 tolerance=None):
    """
    Backward as-of join of ``left_values`` onto ``right_values``, matching
    only rows whose "by" keys are equal.

    For every left row, the result is the last right position visited so far
    that carries the same by-key and whose value is ``<=`` the left value
    (``<`` when ``allow_exact_matches`` is false).

    Parameters
    ----------
    left_values, right_values : ndarray[asof_t]
        Values to join on.
        NOTE(review): the single forward cursor over ``right_values`` only
        makes sense if both arrays are sorted ascending -- confirm at caller.
    left_by_values, right_by_values : ndarray[by_t]
        Grouping key of each left / right row.
    allow_exact_matches : bint, default 1
        Whether a right value equal to the left value counts as a match.
    tolerance : optional
        If not None, matches with ``left - right > tolerance`` are dropped.

    Returns
    -------
    (left_indexer, right_indexer) : tuple of ndarray[int64_t]
        Both have ``len(left_values)`` entries. ``left_indexer[i] == i`` and
        ``right_indexer[i]`` is the matched right position, or -1 for none.
    """
    cdef:
        Py_ssize_t lpos, rpos, nleft, nright, hit
        ndarray[int64_t] left_indexer, right_indexer
        bint use_tol = 0
        asof_t max_gap = 0
        asof_t gap = 0
        HashTable last_seen
        by_t key

    # convert the tolerance to the value type once, up front
    use_tol = tolerance is not None
    if use_tol:
        max_gap = tolerance

    nleft = len(left_values)
    nright = len(right_values)

    left_indexer = np.empty(nleft, dtype=np.int64)
    right_indexer = np.empty(nleft, dtype=np.int64)

    # by-key -> most recent right position already passed by the cursor
    if by_t is object:
        last_seen = PyObjectHashTable(nright)
    elif by_t is int64_t:
        last_seen = Int64HashTable(nright)
    elif by_t is uint64_t:
        last_seen = UInt64HashTable(nright)

    rpos = 0
    for lpos in range(nleft):
        # the cursor is left at -1 when nothing had been consumed yet
        if rpos < 0:
            rpos = 0

        # consume every right row that is not after the current left row,
        # recording it under its by-key as we go
        if allow_exact_matches:
            while (rpos < nright and
                    right_values[rpos] <= left_values[lpos]):
                last_seen.set_item(right_by_values[rpos], rpos)
                rpos += 1
        else:
            while (rpos < nright and
                    right_values[rpos] < left_values[lpos]):
                last_seen.set_item(right_by_values[rpos], rpos)
                rpos += 1
        rpos -= 1

        # look up the latest right row sharing this left row's by-key
        key = left_by_values[lpos]
        if key in last_seen:
            hit = last_seen.get_item(key)
        else:
            hit = -1

        left_indexer[lpos] = lpos

        if use_tol and hit != -1:
            # discard the match when it lies further back than allowed
            gap = left_values[lpos] - right_values[hit]
            right_indexer[lpos] = -1 if gap > max_gap else hit
        else:
            right_indexer[lpos] = hit

    return left_indexer, right_indexer
751+
752+
753+
def asof_join_forward_on_X_by_Y(ndarray[asof_t] left_values,
                                ndarray[asof_t] right_values,
                                ndarray[by_t] left_by_values,
                                ndarray[by_t] right_by_values,
                                bint allow_exact_matches=1,
                                tolerance=None):
    """
    Forward as-of join of ``left_values`` onto ``right_values``, matching
    only rows whose "by" keys are equal.

    Both arrays are walked from the end towards the start. For every left
    row, the result is the lowest right position visited so far that carries
    the same by-key and whose value is ``>=`` the left value (``>`` when
    ``allow_exact_matches`` is false).

    Parameters
    ----------
    left_values, right_values : ndarray[asof_t]
        Values to join on.
        NOTE(review): the single backward cursor over ``right_values`` only
        makes sense if both arrays are sorted ascending -- confirm at caller.
    left_by_values, right_by_values : ndarray[by_t]
        Grouping key of each left / right row.
    allow_exact_matches : bint, default 1
        Whether a right value equal to the left value counts as a match.
    tolerance : optional
        If not None, matches with ``right - left > tolerance`` are dropped.

    Returns
    -------
    (left_indexer, right_indexer) : tuple of ndarray[int64_t]
        Both have ``len(left_values)`` entries. ``left_indexer[i] == i`` and
        ``right_indexer[i]`` is the matched right position, or -1 for none.
    """
    cdef:
        Py_ssize_t lpos, rpos, nleft, nright, hit
        ndarray[int64_t] left_indexer, right_indexer
        bint use_tol = 0
        asof_t max_gap = 0
        asof_t gap = 0
        HashTable first_seen
        by_t key

    # convert the tolerance to the value type once, up front
    use_tol = tolerance is not None
    if use_tol:
        max_gap = tolerance

    nleft = len(left_values)
    nright = len(right_values)

    left_indexer = np.empty(nleft, dtype=np.int64)
    right_indexer = np.empty(nleft, dtype=np.int64)

    # by-key -> lowest right position already passed by the cursor
    if by_t is object:
        first_seen = PyObjectHashTable(nright)
    elif by_t is int64_t:
        first_seen = Int64HashTable(nright)
    elif by_t is uint64_t:
        first_seen = UInt64HashTable(nright)

    rpos = nright - 1
    for lpos in range(nleft - 1, -1, -1):
        # the cursor is left one past the end when nothing was consumed yet
        if rpos == nright:
            rpos = nright - 1

        # consume every right row that is not before the current left row,
        # recording it under its by-key as we go
        if allow_exact_matches:
            while (rpos >= 0 and
                    right_values[rpos] >= left_values[lpos]):
                first_seen.set_item(right_by_values[rpos], rpos)
                rpos -= 1
        else:
            while (rpos >= 0 and
                    right_values[rpos] > left_values[lpos]):
                first_seen.set_item(right_by_values[rpos], rpos)
                rpos -= 1
        rpos += 1

        # look up the earliest right row sharing this left row's by-key
        key = left_by_values[lpos]
        if key in first_seen:
            hit = first_seen.get_item(key)
        else:
            hit = -1

        left_indexer[lpos] = lpos

        if use_tol and hit != -1:
            # discard the match when it lies further ahead than allowed
            gap = right_values[hit] - left_values[lpos]
            right_indexer[lpos] = -1 if gap > max_gap else hit
        else:
            right_indexer[lpos] = hit

    return left_indexer, right_indexer
820+
821+
822+
def asof_join_nearest_on_X_by_Y(ndarray[asof_t] left_values,
                                ndarray[asof_t] right_values,
                                ndarray[by_t] left_by_values,
                                ndarray[by_t] right_by_values,
                                bint allow_exact_matches=1,
                                tolerance=None):
    """
    Nearest as-of join of ``left_values`` onto ``right_values``, matching
    only rows whose "by" keys are equal.

    Runs both ``asof_join_backward_on_X_by_Y`` and
    ``asof_join_forward_on_X_by_Y`` with the same arguments and, per left
    row, keeps whichever match is closer in value. On a tie the backward
    match wins. If only one direction found a match it is used; if neither
    did, the result is -1.

    Parameters
    ----------
    left_values, right_values : ndarray[asof_t]
        Values to join on.
    left_by_values, right_by_values : ndarray[by_t]
        Grouping key of each left / right row.
    allow_exact_matches : bint, default 1
        Forwarded unchanged to the backward and forward joins.
    tolerance : optional
        Forwarded unchanged to the backward and forward joins.

    Returns
    -------
    (left_indexer, right_indexer) : tuple of ndarray[int64_t]
        Both have ``len(left_values)`` entries. ``left_indexer`` is copied
        from the backward join's left indexer; ``right_indexer[i]`` is the
        chosen right position, or -1 for no match.
    """
    cdef:
        Py_ssize_t left_size, i
        ndarray[int64_t] left_indexer, right_indexer, bli, bri, fli, fri
        asof_t bdiff, fdiff

    left_size = len(left_values)

    left_indexer = np.empty(left_size, dtype=np.int64)
    right_indexer = np.empty(left_size, dtype=np.int64)

    # search both forward and backward
    bli, bri = asof_join_backward_on_X_by_Y(left_values,
                                            right_values,
                                            left_by_values,
                                            right_by_values,
                                            allow_exact_matches,
                                            tolerance)
    fli, fri = asof_join_forward_on_X_by_Y(left_values,
                                           right_values,
                                           left_by_values,
                                           right_by_values,
                                           allow_exact_matches,
                                           tolerance)

    # both helper results have one entry per left row
    for i in range(left_size):
        if bri[i] != -1 and fri[i] != -1:
            # both directions matched: keep the one with the smaller
            # difference, preferring the backward match on ties
            bdiff = left_values[bli[i]] - right_values[bri[i]]
            fdiff = right_values[fri[i]] - left_values[fli[i]]
            right_indexer[i] = bri[i] if bdiff <= fdiff else fri[i]
        else:
            # at most one direction matched (fri[i] is -1 if neither did)
            right_indexer[i] = bri[i] if bri[i] != -1 else fri[i]
        left_indexer[i] = bli[i]

    return left_indexer, right_indexer
865+
866+
867+
# ----------------------------------------------------------------------
868+
# asof_join
869+
# ----------------------------------------------------------------------
870+
871+
def asof_join_backward(ndarray[asof_t] left_values,
                       ndarray[asof_t] right_values,
                       bint allow_exact_matches=1,
                       tolerance=None):
    """
    Backward as-of join of ``left_values`` onto ``right_values``.

    For every left row, the result is the last right position whose value is
    ``<=`` the left value (``<`` when ``allow_exact_matches`` is false).

    Parameters
    ----------
    left_values, right_values : ndarray[asof_t]
        Values to join on.
        NOTE(review): the single forward cursor over ``right_values`` only
        makes sense if both arrays are sorted ascending -- confirm at caller.
    allow_exact_matches : bint, default 1
        Whether a right value equal to the left value counts as a match.
    tolerance : optional
        If not None, matches with ``left - right > tolerance`` are dropped.

    Returns
    -------
    (left_indexer, right_indexer) : tuple of ndarray[int64_t]
        Both have ``len(left_values)`` entries. ``left_indexer[i] == i`` and
        ``right_indexer[i]`` is the matched right position, or -1 for none.
    """
    cdef:
        Py_ssize_t lpos, rpos, nleft, nright
        ndarray[int64_t] left_indexer, right_indexer
        bint use_tol = 0
        asof_t max_gap = 0
        asof_t gap = 0

    # convert the tolerance to the value type once, up front
    use_tol = tolerance is not None
    if use_tol:
        max_gap = tolerance

    nleft = len(left_values)
    nright = len(right_values)

    left_indexer = np.empty(nleft, dtype=np.int64)
    right_indexer = np.empty(nleft, dtype=np.int64)

    rpos = 0
    for lpos in range(nleft):
        # the cursor is left at -1 when the previous row had no match
        if rpos < 0:
            rpos = 0

        # advance past every right row that is not after the left row, then
        # step back onto the last one that qualified
        if allow_exact_matches:
            while (rpos < nright and
                    right_values[rpos] <= left_values[lpos]):
                rpos += 1
        else:
            while (rpos < nright and
                    right_values[rpos] < left_values[lpos]):
                rpos += 1
        rpos -= 1

        left_indexer[lpos] = lpos

        if use_tol and rpos != -1:
            # discard the match when it lies further back than allowed
            gap = left_values[lpos] - right_values[rpos]
            right_indexer[lpos] = -1 if gap > max_gap else rpos
        else:
            right_indexer[lpos] = rpos

    return left_indexer, right_indexer
922+
923+
924+
def asof_join_forward(ndarray[asof_t] left_values,
                      ndarray[asof_t] right_values,
                      bint allow_exact_matches=1,
                      tolerance=None):
    """
    Forward as-of join of ``left_values`` onto ``right_values``.

    Both arrays are walked from the end towards the start. For every left
    row, the result is the first right position whose value is ``>=`` the
    left value (``>`` when ``allow_exact_matches`` is false).

    Parameters
    ----------
    left_values, right_values : ndarray[asof_t]
        Values to join on.
        NOTE(review): the single backward cursor over ``right_values`` only
        makes sense if both arrays are sorted ascending -- confirm at caller.
    allow_exact_matches : bint, default 1
        Whether a right value equal to the left value counts as a match.
    tolerance : optional
        If not None, matches with ``right - left > tolerance`` are dropped.

    Returns
    -------
    (left_indexer, right_indexer) : tuple of ndarray[int64_t]
        Both have ``len(left_values)`` entries. ``left_indexer[i] == i`` and
        ``right_indexer[i]`` is the matched right position, or -1 for none.
    """
    cdef:
        Py_ssize_t lpos, rpos, nleft, nright
        ndarray[int64_t] left_indexer, right_indexer
        bint use_tol = 0
        asof_t max_gap = 0
        asof_t gap = 0

    # convert the tolerance to the value type once, up front
    use_tol = tolerance is not None
    if use_tol:
        max_gap = tolerance

    nleft = len(left_values)
    nright = len(right_values)

    left_indexer = np.empty(nleft, dtype=np.int64)
    right_indexer = np.empty(nleft, dtype=np.int64)

    rpos = nright - 1
    for lpos in range(nleft - 1, -1, -1):
        # the cursor is left one past the end when the previous row had no
        # match
        if rpos == nright:
            rpos = nright - 1

        # retreat past every right row that is not before the left row, then
        # step forward onto the first one that qualified
        if allow_exact_matches:
            while (rpos >= 0 and
                    right_values[rpos] >= left_values[lpos]):
                rpos -= 1
        else:
            while (rpos >= 0 and
                    right_values[rpos] > left_values[lpos]):
                rpos -= 1
        rpos += 1

        left_indexer[lpos] = lpos

        if rpos == nright:
            # ran off the end: nothing at or after this left row
            right_indexer[lpos] = -1
        elif use_tol:
            # discard the match when it lies further ahead than allowed
            gap = right_values[rpos] - left_values[lpos]
            right_indexer[lpos] = -1 if gap > max_gap else rpos
        else:
            right_indexer[lpos] = rpos

    return left_indexer, right_indexer
976+
977+
978+
def asof_join_nearest(ndarray[asof_t] left_values,
                      ndarray[asof_t] right_values,
                      bint allow_exact_matches=1,
                      tolerance=None):
    """
    Nearest as-of join of ``left_values`` onto ``right_values``.

    Runs both ``asof_join_backward`` and ``asof_join_forward`` with the same
    arguments and, per left row, keeps whichever match is closer in value.
    On a tie the backward match wins. If only one direction found a match it
    is used; if neither did, the result is -1.

    Parameters
    ----------
    left_values, right_values : ndarray[asof_t]
        Values to join on.
    allow_exact_matches : bint, default 1
        Forwarded unchanged to the backward and forward joins.
    tolerance : optional
        Forwarded unchanged to the backward and forward joins.

    Returns
    -------
    (left_indexer, right_indexer) : tuple of ndarray[int64_t]
        Both have ``len(left_values)`` entries. ``left_indexer`` is copied
        from the backward join's left indexer; ``right_indexer[i]`` is the
        chosen right position, or -1 for no match.
    """
    cdef:
        Py_ssize_t left_size, i
        ndarray[int64_t] left_indexer, right_indexer, bli, bri, fli, fri
        asof_t bdiff, fdiff

    left_size = len(left_values)

    left_indexer = np.empty(left_size, dtype=np.int64)
    right_indexer = np.empty(left_size, dtype=np.int64)

    # search both forward and backward
    bli, bri = asof_join_backward(left_values, right_values,
                                  allow_exact_matches, tolerance)
    fli, fri = asof_join_forward(left_values, right_values,
                                 allow_exact_matches, tolerance)

    # both helper results have one entry per left row
    for i in range(left_size):
        if bri[i] != -1 and fri[i] != -1:
            # both directions matched: keep the one with the smaller
            # difference, preferring the backward match on ties
            bdiff = left_values[bli[i]] - right_values[bri[i]]
            fdiff = right_values[fri[i]] - left_values[fli[i]]
            right_indexer[i] = bri[i] if bdiff <= fdiff else fri[i]
        else:
            # at most one direction matched (fri[i] is -1 if neither did)
            right_indexer[i] = bri[i] if bri[i] != -1 else fri[i]
        left_indexer[i] = bli[i]

    return left_indexer, right_indexer

0 commit comments

Comments
 (0)