@@ -280,19 +280,28 @@ def _indicator_post_merge(self, result):
280
280
return result
281
281
282
282
def _maybe_add_join_keys(self, result, left_indexer, right_indexer):
    """
    Fill in or insert the join-key columns of ``result`` in place.

    For every ``(name, lname, rname)`` triple built from
    ``self.join_names``, ``self.left_on`` and ``self.right_on`` that
    ``_should_fill`` accepts, the key column is rebuilt from the left
    and/or right key values so that rows whose indexer entry is ``-1`` on
    one side receive the value from the other side.

    Parameters
    ----------
    result : DataFrame
        The merged frame; modified in place.
    left_indexer, right_indexer : ndarray or None
        Positions taken from the left / right inputs for each result row.
        An entry of ``-1`` is treated as "no row on that side".

    Returns
    -------
    None
    """
    # NOTE(review): this body resolves the merge conflict in favour of the
    # incoming "Preserve dtype in merge keys when possible" change; the
    # older put/iloc based filling from HEAD is dropped.

    # Set when a key column is inserted, or replaced by a column with a
    # different dtype, so the frame is consolidated once after the loop.
    consolidate = False

    # Computed lazily, at most once, and reused for every key.
    left_has_missing = None
    right_has_missing = None

    keys = zip(self.join_names, self.left_on, self.right_on)
    for i, (name, lname, rname) in enumerate(keys):
        if not _should_fill(lname, rname):
            continue

        # Sources to take key values from; None means "use what is
        # already in result[name]".
        take_left, take_right = None, None

        if name in result:

            if left_indexer is not None and right_indexer is not None:
                if name in self.left:

                    if left_has_missing is None:
                        left_has_missing = (left_indexer == -1).any()

                    if left_has_missing:
                        take_right = self.right_join_keys[i]

                        # result's column was upcast relative to the
                        # left input: re-take from the original values.
                        if result[name].dtype != self.left[name].dtype:
                            take_left = self.left[name].values

                elif name in self.right:

                    if right_has_missing is None:
                        right_has_missing = (right_indexer == -1).any()

                    if right_has_missing:
                        take_left = self.left_join_keys[i]

                        if result[name].dtype != self.right[name].dtype:
                            take_right = self.right[name].values

        elif (left_indexer is not None
              and isinstance(self.left_join_keys[i], np.ndarray)):

            take_left = self.left_join_keys[i]
            take_right = self.right_join_keys[i]

        if take_left is not None or take_right is not None:

            if take_left is None:
                lvals = result[name].values
            else:
                # Fill with the dtype's default scalar so that taking at
                # -1 does not force an upcast of the key values.
                lfill = take_left.dtype.type()
                lvals = algos.take_1d(take_left, left_indexer,
                                      fill_value=lfill)

            if take_right is None:
                rvals = result[name].values
            else:
                rfill = take_right.dtype.type()
                rvals = algos.take_1d(take_right, right_indexer,
                                      fill_value=rfill)

            # Prefer the left value; fall back to the right value on rows
            # that have no left match.
            key_col = np.where(left_indexer != -1, lvals, rvals)

            if name in result:
                if result[name].dtype != key_col.dtype:
                    consolidate = True
                result[name] = key_col
            else:
                result.insert(i, name or 'key_%d' % i, key_col)
                consolidate = True

    if consolidate:
        result.consolidate(inplace=True)
332
393
333
394
def _get_join_info (self ):
334
395
left_ax = self .left ._data .axes [self .axis ]
0 commit comments