@@ -56,12 +56,8 @@ def __init__(self, left, right, how='inner', on=None,
56
56
self .left_on = com ._maybe_make_list (left_on )
57
57
self .right_on = com ._maybe_make_list (right_on )
58
58
59
- self .drop_keys = False # set this later...kludge
60
-
61
59
self .copy = copy
62
-
63
60
self .suffixes = suffixes
64
-
65
61
self .sort = sort
66
62
67
63
self .left_index = left_index
@@ -91,26 +87,33 @@ def get_result(self):
91
87
return result
92
88
93
89
def _maybe_add_join_keys (self , result , left_indexer , right_indexer ):
94
- if not self .drop_keys :
95
- # do nothing, already found in one of the DataFrames
96
- return
97
-
98
90
# insert group keys
99
- for i , name in enumerate (self .join_names ):
91
+
92
+ keys = zip (self .join_names , self .left_on , self .right_on )
93
+ for i , (name , lname , rname ) in enumerate (keys ):
94
+ if not _should_fill (lname , rname ):
95
+ continue
96
+
100
97
if name in result :
101
98
key_col = result [name ]
102
99
103
- if name in self .left :
100
+ if name in self .left and left_indexer is not None :
104
101
na_indexer = (left_indexer == - 1 ).nonzero ()[0 ]
102
+ if len (na_indexer ) == 0 :
103
+ continue
104
+
105
105
right_na_indexer = right_indexer .take (na_indexer )
106
106
key_col .put (na_indexer , com .take_1d (self .right_join_keys [i ],
107
107
right_na_indexer ))
108
- else :
108
+ elif name in self . right and right_indexer is not None :
109
109
na_indexer = (right_indexer == - 1 ).nonzero ()[0 ]
110
- left_na_indexer = right_indexer .take (na_indexer )
110
+ if len (na_indexer ) == 0 :
111
+ continue
112
+
113
+ left_na_indexer = left_indexer .take (na_indexer )
111
114
key_col .put (na_indexer , com .take_1d (self .left_join_keys [i ],
112
115
left_na_indexer ))
113
- else :
116
+ elif left_indexer is not None :
114
117
# a faster way?
115
118
key_col = com .take_1d (self .left_join_keys [i ], left_indexer )
116
119
na_indexer = (left_indexer == - 1 ).nonzero ()[0 ]
@@ -181,30 +184,41 @@ def _get_merge_keys(self):
181
184
and self .right_on is None ):
182
185
183
186
if self .left_index and self .right_index :
184
- pass
187
+ self . left_on , self . right_on = (), ()
185
188
elif self .left_index :
186
189
if self .right_on is None :
187
190
raise Exception ('Must pass right_on or right_index=True' )
191
+ self .left_on = [None ] * self .left .index .nlevels
188
192
elif self .right_index :
189
193
if self .left_on is None :
190
194
raise Exception ('Must pass left_on or left_index=True' )
195
+ self .right_on = [None ] * self .right .index .nlevels
191
196
else :
192
197
# use the common columns
193
198
common_cols = self .left .columns .intersection (self .right .columns )
194
199
self .left_on = self .right_on = common_cols
195
- self .drop_keys = True
196
-
197
200
elif self .on is not None :
198
201
if self .left_on is not None or self .right_on is not None :
199
202
raise Exception ('Can only pass on OR left_on and '
200
203
'right_on' )
201
204
self .left_on = self .right_on = self .on
202
- self .drop_keys = True
205
+ elif self .left_on is not None :
206
+ n = len (self .left_on )
207
+ if self .right_index :
208
+ self .right_on = [None ] * n
209
+ else :
210
+ assert (len (self .right_on ) == n )
211
+ elif self .right_on is not None :
212
+ n = len (self .right_on )
213
+ if self .left_index :
214
+ self .left_on = [None ] * n
215
+ else :
216
+ assert (len (self .left_on ) == n )
203
217
204
218
left_keys = []
205
219
right_keys = []
206
220
join_names = []
207
- left_drop , right_drop = [], []
221
+ right_drop = []
208
222
left , right = self .left , self .right
209
223
210
224
is_lkey = lambda x : isinstance (x , np .ndarray ) and len (x ) == len (left )
@@ -249,8 +263,6 @@ def _get_merge_keys(self):
249
263
250
264
if right_drop :
251
265
self .right = self .right .drop (right_drop , axis = 1 )
252
- if left_drop :
253
- self .left = self .left .drop (left_drop , axis = 1 )
254
266
255
267
return left_keys , right_keys , join_names
256
268
@@ -1006,6 +1018,11 @@ def _consensus_name_attr(objs):
1006
1018
return None
1007
1019
return name
1008
1020
1021
+ def _should_fill (lname , rname ):
1022
+ if not isinstance (lname , basestring ) or not isinstance (rname , basestring ):
1023
+ return True
1024
+ return lname == rname
1025
+
1009
1026
def _all_indexes_same (indexes ):
1010
1027
first = indexes [0 ]
1011
1028
for index in indexes [1 :]:
@@ -1014,4 +1031,4 @@ def _all_indexes_same(indexes):
1014
1031
return True
1015
1032
1016
1033
def _any (x ):
1017
- return x is not None and len (x ) > 0
1034
+ return x is not None and len (x ) > 0 and any ([ y is not None for y in x ])
0 commit comments