Skip to content

Commit 4bfdde4

Browse files
committed
BUG: fix several many-to-one join bugs
1 parent 16d2cdd commit 4bfdde4

File tree

3 files changed

+343
-200
lines changed

3 files changed

+343
-200
lines changed

pandas/src/generate_code.py

+62-40
Original file line numberDiff line numberDiff line change
@@ -660,8 +660,10 @@ def left_join_indexer_unique_%(name)s(ndarray[%(c_type)s] left,
660660
661661
"""
662662

663-
left_join_template = """@cython.wraparound(False)
664-
@cython.boundscheck(False)
663+
# @cython.wraparound(False)
664+
# @cython.boundscheck(False)
665+
666+
left_join_template = """
665667
def left_join_indexer_%(name)s(ndarray[%(c_type)s] left,
666668
ndarray[%(c_type)s] right):
667669
'''
@@ -691,9 +693,12 @@ def left_join_indexer_%(name)s(ndarray[%(c_type)s] left,
691693
if lval == rval:
692694
count += 1
693695
if i < nleft - 1:
694-
i += 1
695-
if left[i] != rval:
696+
if j < nright - 1 and right[j + 1] == rval:
696697
j += 1
698+
else:
699+
i += 1
700+
if left[i] != rval:
701+
j += 1
697702
elif j < nright - 1:
698703
j += 1
699704
if lval != right[j]:
@@ -725,6 +730,7 @@ def left_join_indexer_%(name)s(ndarray[%(c_type)s] left,
725730
result[count] = left[i]
726731
i += 1
727732
count += 1
733+
break
728734
729735
lval = left[i]
730736
rval = right[j]
@@ -735,9 +741,12 @@ def left_join_indexer_%(name)s(ndarray[%(c_type)s] left,
735741
result[count] = lval
736742
count += 1
737743
if i < nleft - 1:
738-
i += 1
739-
if left[i] != rval:
744+
if j < nright - 1 and right[j + 1] == rval:
740745
j += 1
746+
else:
747+
i += 1
748+
if left[i] != rval:
749+
j += 1
741750
elif j < nright - 1:
742751
j += 1
743752
if lval != right[j]:
@@ -779,31 +788,34 @@ def inner_join_indexer_%(name)s(ndarray[%(c_type)s] left,
779788
j = 0
780789
count = 0
781790
if nleft > 0 and nright > 0:
782-
lval = left[0]
783-
rval = right[0]
784791
while True:
792+
if i == nleft:
793+
break
794+
if j == nright:
795+
break
796+
797+
lval = left[i]
798+
rval = right[j]
785799
if lval == rval:
786800
count += 1
787801
if i < nleft - 1:
788-
i += 1
789-
lval = left[i]
802+
if j < nright - 1 and right[j + 1] == rval:
803+
j += 1
804+
else:
805+
i += 1
806+
if left[i] != rval:
807+
j += 1
790808
elif j < nright - 1:
791809
j += 1
792-
rval = right[j]
810+
if lval != right[j]:
811+
i += 1
793812
else:
813+
# end of the road
794814
break
795815
elif lval < rval:
796-
if i < nleft - 1:
797-
i += 1
798-
lval = left[i]
799-
else:
800-
break
816+
i += 1
801817
else:
802-
if j < nright - 1:
803-
j += 1
804-
rval = right[j]
805-
else:
806-
break
818+
j += 1
807819
808820
# do it again now that result size is known
809821
@@ -815,34 +827,37 @@ def inner_join_indexer_%(name)s(ndarray[%(c_type)s] left,
815827
j = 0
816828
count = 0
817829
if nleft > 0 and nright > 0:
818-
lval = left[0]
819-
rval = right[0]
820830
while True:
831+
if i == nleft:
832+
break
833+
if j == nright:
834+
break
835+
836+
lval = left[i]
837+
rval = right[j]
821838
if lval == rval:
822839
lindexer[count] = i
823840
rindexer[count] = j
824841
result[count] = rval
825842
count += 1
826843
if i < nleft - 1:
827-
i += 1
828-
lval = left[i]
844+
if j < nright - 1 and right[j + 1] == rval:
845+
j += 1
846+
else:
847+
i += 1
848+
if left[i] != rval:
849+
j += 1
829850
elif j < nright - 1:
830851
j += 1
831-
rval = right[j]
852+
if lval != right[j]:
853+
i += 1
832854
else:
855+
# end of the road
833856
break
834857
elif lval < rval:
835-
if i < nleft - 1:
836-
i += 1
837-
lval = left[i]
838-
else:
839-
break
858+
i += 1
840859
else:
841-
if j < nright - 1:
842-
j += 1
843-
rval = right[j]
844-
else:
845-
break
860+
j += 1
846861
847862
return result, lindexer, rindexer
848863
@@ -883,9 +898,12 @@ def outer_join_indexer_%(name)s(ndarray[%(c_type)s] left,
883898
if lval == rval:
884899
count += 1
885900
if i < nleft - 1:
886-
i += 1
887-
if left[i] != rval:
901+
if j < nright - 1 and right[j + 1] == rval:
888902
j += 1
903+
else:
904+
i += 1
905+
if left[i] != rval:
906+
j += 1
889907
elif j < nright - 1:
890908
j += 1
891909
if lval != right[j]:
@@ -947,14 +965,18 @@ def outer_join_indexer_%(name)s(ndarray[%(c_type)s] left,
947965
result[count] = lval
948966
count += 1
949967
if i < nleft - 1:
950-
i += 1
951-
if left[i] != rval:
968+
if j < nright - 1 and right[j + 1] == rval:
952969
j += 1
970+
else:
971+
i += 1
972+
if left[i] != rval:
973+
j += 1
953974
elif j < nright - 1:
954975
j += 1
955976
if lval != right[j]:
956977
i += 1
957978
else:
979+
# end of the road
958980
break
959981
elif lval < rval:
960982
lindexer[count] = i

0 commit comments

Comments
 (0)