Skip to content

Commit 2fdeb07

Browse files
authored
REF: de-duplicate libjoin (pandas-dev#46256)
1 parent 0be6fd3 commit 2fdeb07

File tree

1 file changed

+89
-149
lines changed

1 file changed

+89
-149
lines changed

pandas/_libs/join.pyx

+89 -149
Original file line number | Diff line number | Diff line change
@@ -93,10 +93,13 @@ def left_outer_join(const intp_t[:] left, const intp_t[:] right,
9393
with nogil:
9494
# First pass, determine size of result set, do not use the NA group
9595
for i in range(1, max_groups + 1):
96-
if right_count[i] > 0:
97-
count += left_count[i] * right_count[i]
96+
lc = left_count[i]
97+
rc = right_count[i]
98+
99+
if rc > 0:
100+
count += lc * rc
98101
else:
99-
count += left_count[i]
102+
count += lc
100103

101104
left_indexer = np.empty(count, dtype=np.intp)
102105
right_indexer = np.empty(count, dtype=np.intp)
@@ -679,7 +682,8 @@ def asof_join_backward_on_X_by_Y(numeric_t[:] left_values,
679682
by_t[:] left_by_values,
680683
by_t[:] right_by_values,
681684
bint allow_exact_matches=True,
682-
tolerance=None):
685+
tolerance=None,
686+
bint use_hashtable=True):
683687

684688
cdef:
685689
Py_ssize_t left_pos, right_pos, left_size, right_size, found_right_pos
@@ -701,12 +705,13 @@ def asof_join_backward_on_X_by_Y(numeric_t[:] left_values,
701705
left_indexer = np.empty(left_size, dtype=np.intp)
702706
right_indexer = np.empty(left_size, dtype=np.intp)
703707

704-
if by_t is object:
705-
hash_table = PyObjectHashTable(right_size)
706-
elif by_t is int64_t:
707-
hash_table = Int64HashTable(right_size)
708-
elif by_t is uint64_t:
709-
hash_table = UInt64HashTable(right_size)
708+
if use_hashtable:
709+
if by_t is object:
710+
hash_table = PyObjectHashTable(right_size)
711+
elif by_t is int64_t:
712+
hash_table = Int64HashTable(right_size)
713+
elif by_t is uint64_t:
714+
hash_table = UInt64HashTable(right_size)
710715

711716
right_pos = 0
712717
for left_pos in range(left_size):
@@ -718,19 +723,25 @@ def asof_join_backward_on_X_by_Y(numeric_t[:] left_values,
718723
if allow_exact_matches:
719724
while (right_pos < right_size and
720725
right_values[right_pos] <= left_values[left_pos]):
721-
hash_table.set_item(right_by_values[right_pos], right_pos)
726+
if use_hashtable:
727+
hash_table.set_item(right_by_values[right_pos], right_pos)
722728
right_pos += 1
723729
else:
724730
while (right_pos < right_size and
725731
right_values[right_pos] < left_values[left_pos]):
726-
hash_table.set_item(right_by_values[right_pos], right_pos)
732+
if use_hashtable:
733+
hash_table.set_item(right_by_values[right_pos], right_pos)
727734
right_pos += 1
728735
right_pos -= 1
729736

730737
# save positions as the desired index
731-
by_value = left_by_values[left_pos]
732-
found_right_pos = (hash_table.get_item(by_value)
733-
if by_value in hash_table else -1)
738+
if use_hashtable:
739+
by_value = left_by_values[left_pos]
740+
found_right_pos = (hash_table.get_item(by_value)
741+
if by_value in hash_table else -1)
742+
else:
743+
found_right_pos = right_pos
744+
734745
left_indexer[left_pos] = left_pos
735746
right_indexer[left_pos] = found_right_pos
736747

@@ -748,7 +759,8 @@ def asof_join_forward_on_X_by_Y(numeric_t[:] left_values,
748759
by_t[:] left_by_values,
749760
by_t[:] right_by_values,
750761
bint allow_exact_matches=1,
751-
tolerance=None):
762+
tolerance=None,
763+
bint use_hashtable=True):
752764

753765
cdef:
754766
Py_ssize_t left_pos, right_pos, left_size, right_size, found_right_pos
@@ -770,12 +782,13 @@ def asof_join_forward_on_X_by_Y(numeric_t[:] left_values,
770782
left_indexer = np.empty(left_size, dtype=np.intp)
771783
right_indexer = np.empty(left_size, dtype=np.intp)
772784

773-
if by_t is object:
774-
hash_table = PyObjectHashTable(right_size)
775-
elif by_t is int64_t:
776-
hash_table = Int64HashTable(right_size)
777-
elif by_t is uint64_t:
778-
hash_table = UInt64HashTable(right_size)
785+
if use_hashtable:
786+
if by_t is object:
787+
hash_table = PyObjectHashTable(right_size)
788+
elif by_t is int64_t:
789+
hash_table = Int64HashTable(right_size)
790+
elif by_t is uint64_t:
791+
hash_table = UInt64HashTable(right_size)
779792

780793
right_pos = right_size - 1
781794
for left_pos in range(left_size - 1, -1, -1):
@@ -787,19 +800,26 @@ def asof_join_forward_on_X_by_Y(numeric_t[:] left_values,
787800
if allow_exact_matches:
788801
while (right_pos >= 0 and
789802
right_values[right_pos] >= left_values[left_pos]):
790-
hash_table.set_item(right_by_values[right_pos], right_pos)
803+
if use_hashtable:
804+
hash_table.set_item(right_by_values[right_pos], right_pos)
791805
right_pos -= 1
792806
else:
793807
while (right_pos >= 0 and
794808
right_values[right_pos] > left_values[left_pos]):
795-
hash_table.set_item(right_by_values[right_pos], right_pos)
809+
if use_hashtable:
810+
hash_table.set_item(right_by_values[right_pos], right_pos)
796811
right_pos -= 1
797812
right_pos += 1
798813

799814
# save positions as the desired index
800-
by_value = left_by_values[left_pos]
801-
found_right_pos = (hash_table.get_item(by_value)
802-
if by_value in hash_table else -1)
815+
if use_hashtable:
816+
by_value = left_by_values[left_pos]
817+
found_right_pos = (hash_table.get_item(by_value)
818+
if by_value in hash_table else -1)
819+
else:
820+
found_right_pos = (right_pos
821+
if right_pos != right_size else -1)
822+
803823
left_indexer[left_pos] = left_pos
804824
right_indexer[left_pos] = found_right_pos
805825

@@ -820,15 +840,7 @@ def asof_join_nearest_on_X_by_Y(numeric_t[:] left_values,
820840
tolerance=None):
821841

822842
cdef:
823-
Py_ssize_t left_size, right_size, i
824-
ndarray[intp_t] left_indexer, right_indexer, bli, bri, fli, fri
825-
numeric_t bdiff, fdiff
826-
827-
left_size = len(left_values)
828-
right_size = len(right_values)
829-
830-
left_indexer = np.empty(left_size, dtype=np.intp)
831-
right_indexer = np.empty(left_size, dtype=np.intp)
843+
ndarray[intp_t] bli, bri, fli, fri
832844

833845
# search both forward and backward
834846
bli, bri = asof_join_backward_on_X_by_Y(
@@ -848,6 +860,27 @@ def asof_join_nearest_on_X_by_Y(numeric_t[:] left_values,
848860
tolerance,
849861
)
850862

863+
return _choose_smaller_timestamp(left_values, right_values, bli, bri, fli, fri)
864+
865+
866+
cdef _choose_smaller_timestamp(
867+
numeric_t[:] left_values,
868+
numeric_t[:] right_values,
869+
ndarray[intp_t] bli,
870+
ndarray[intp_t] bri,
871+
ndarray[intp_t] fli,
872+
ndarray[intp_t] fri,
873+
):
874+
cdef:
875+
ndarray[intp_t] left_indexer, right_indexer
876+
Py_ssize_t left_size, i
877+
numeric_t bdiff, fdiff
878+
879+
left_size = len(left_values)
880+
881+
left_indexer = np.empty(left_size, dtype=np.intp)
882+
right_indexer = np.empty(left_size, dtype=np.intp)
883+
851884
for i in range(len(bri)):
852885
# choose timestamp from right with smaller difference
853886
if bri[i] != -1 and fri[i] != -1:
@@ -870,106 +903,30 @@ def asof_join_backward(numeric_t[:] left_values,
870903
bint allow_exact_matches=True,
871904
tolerance=None):
872905

873-
cdef:
874-
Py_ssize_t left_pos, right_pos, left_size, right_size
875-
ndarray[intp_t] left_indexer, right_indexer
876-
bint has_tolerance = False
877-
numeric_t tolerance_ = 0
878-
numeric_t diff = 0
879-
880-
# if we are using tolerance, set our objects
881-
if tolerance is not None:
882-
has_tolerance = True
883-
tolerance_ = tolerance
884-
885-
left_size = len(left_values)
886-
right_size = len(right_values)
887-
888-
left_indexer = np.empty(left_size, dtype=np.intp)
889-
right_indexer = np.empty(left_size, dtype=np.intp)
890-
891-
right_pos = 0
892-
for left_pos in range(left_size):
893-
# restart right_pos if it went negative in a previous iteration
894-
if right_pos < 0:
895-
right_pos = 0
896-
897-
# find last position in right whose value is less than left's
898-
if allow_exact_matches:
899-
while (right_pos < right_size and
900-
right_values[right_pos] <= left_values[left_pos]):
901-
right_pos += 1
902-
else:
903-
while (right_pos < right_size and
904-
right_values[right_pos] < left_values[left_pos]):
905-
right_pos += 1
906-
right_pos -= 1
907-
908-
# save positions as the desired index
909-
left_indexer[left_pos] = left_pos
910-
right_indexer[left_pos] = right_pos
911-
912-
# if needed, verify that tolerance is met
913-
if has_tolerance and right_pos != -1:
914-
diff = left_values[left_pos] - right_values[right_pos]
915-
if diff > tolerance_:
916-
right_indexer[left_pos] = -1
917-
918-
return left_indexer, right_indexer
906+
return asof_join_backward_on_X_by_Y(
907+
left_values,
908+
right_values,
909+
None,
910+
None,
911+
allow_exact_matches=allow_exact_matches,
912+
tolerance=tolerance,
913+
use_hashtable=False,
914+
)
919915

920916

921917
def asof_join_forward(numeric_t[:] left_values,
922918
numeric_t[:] right_values,
923919
bint allow_exact_matches=True,
924920
tolerance=None):
925-
926-
cdef:
927-
Py_ssize_t left_pos, right_pos, left_size, right_size
928-
ndarray[intp_t] left_indexer, right_indexer
929-
bint has_tolerance = False
930-
numeric_t tolerance_ = 0
931-
numeric_t diff = 0
932-
933-
# if we are using tolerance, set our objects
934-
if tolerance is not None:
935-
has_tolerance = True
936-
tolerance_ = tolerance
937-
938-
left_size = len(left_values)
939-
right_size = len(right_values)
940-
941-
left_indexer = np.empty(left_size, dtype=np.intp)
942-
right_indexer = np.empty(left_size, dtype=np.intp)
943-
944-
right_pos = right_size - 1
945-
for left_pos in range(left_size - 1, -1, -1):
946-
# restart right_pos if it went over in a previous iteration
947-
if right_pos == right_size:
948-
right_pos = right_size - 1
949-
950-
# find first position in right whose value is greater than left's
951-
if allow_exact_matches:
952-
while (right_pos >= 0 and
953-
right_values[right_pos] >= left_values[left_pos]):
954-
right_pos -= 1
955-
else:
956-
while (right_pos >= 0 and
957-
right_values[right_pos] > left_values[left_pos]):
958-
right_pos -= 1
959-
right_pos += 1
960-
961-
# save positions as the desired index
962-
left_indexer[left_pos] = left_pos
963-
right_indexer[left_pos] = (right_pos
964-
if right_pos != right_size else -1)
965-
966-
# if needed, verify that tolerance is met
967-
if has_tolerance and right_pos != right_size:
968-
diff = right_values[right_pos] - left_values[left_pos]
969-
if diff > tolerance_:
970-
right_indexer[left_pos] = -1
971-
972-
return left_indexer, right_indexer
921+
return asof_join_forward_on_X_by_Y(
922+
left_values,
923+
right_values,
924+
None,
925+
None,
926+
allow_exact_matches=allow_exact_matches,
927+
tolerance=tolerance,
928+
use_hashtable=False,
929+
)
973930

974931

975932
def asof_join_nearest(numeric_t[:] left_values,
@@ -978,29 +935,12 @@ def asof_join_nearest(numeric_t[:] left_values,
978935
tolerance=None):
979936

980937
cdef:
981-
Py_ssize_t left_size, i
982-
ndarray[intp_t] left_indexer, right_indexer, bli, bri, fli, fri
983-
numeric_t bdiff, fdiff
984-
985-
left_size = len(left_values)
986-
987-
left_indexer = np.empty(left_size, dtype=np.intp)
988-
right_indexer = np.empty(left_size, dtype=np.intp)
938+
ndarray[intp_t] bli, bri, fli, fri
989939

990940
# search both forward and backward
991941
bli, bri = asof_join_backward(left_values, right_values,
992942
allow_exact_matches, tolerance)
993943
fli, fri = asof_join_forward(left_values, right_values,
994944
allow_exact_matches, tolerance)
995945

996-
for i in range(len(bri)):
997-
# choose timestamp from right with smaller difference
998-
if bri[i] != -1 and fri[i] != -1:
999-
bdiff = left_values[bli[i]] - right_values[bri[i]]
1000-
fdiff = right_values[fri[i]] - left_values[fli[i]]
1001-
right_indexer[i] = bri[i] if bdiff <= fdiff else fri[i]
1002-
else:
1003-
right_indexer[i] = bri[i] if bri[i] != -1 else fri[i]
1004-
left_indexer[i] = bli[i]
1005-
1006-
return left_indexer, right_indexer
946+
return _choose_smaller_timestamp(left_values, right_values, bli, bri, fli, fri)

0 commit comments

Comments
 (0)