Skip to content

Commit 4173dbf

Browse files
Mike Kelly authored and jreback committed
Preserve dtype in merge keys when possible
1 parent 0f1666d commit 4173dbf

File tree

2 files changed

+89
-6
lines changed

2 files changed

+89
-6
lines changed

pandas/tools/merge.py

+65 -4
Original file line number | Diff line number | Diff line change
@@ -280,19 +280,28 @@ def _indicator_post_merge(self, result):
280280
return result
281281

282282
def _maybe_add_join_keys(self, result, left_indexer, right_indexer):
283-
# insert group keys
283+
284+
consolidate = False
285+
286+
left_has_missing = None
287+
right_has_missing = None
284288

285289
keys = zip(self.join_names, self.left_on, self.right_on)
286290
for i, (name, lname, rname) in enumerate(keys):
287291
if not _should_fill(lname, rname):
288292
continue
289293

294+
take_left, take_right = None, None
295+
290296
if name in result:
297+
<<<<<<< HEAD
291298
key_indexer = result.columns.get_loc(name)
299+
=======
300+
>>>>>>> e79b978... Preserve dtype in merge keys when possible
292301

293302
if left_indexer is not None and right_indexer is not None:
294-
295303
if name in self.left:
304+
<<<<<<< HEAD
296305
if len(self.left) == 0:
297306
continue
298307

@@ -316,19 +325,71 @@ def _maybe_add_join_keys(self, result, left_indexer, right_indexer):
316325
result.iloc[na_indexer, key_indexer] = (
317326
algos.take_1d(self.left_join_keys[i],
318327
left_na_indexer))
328+
=======
329+
330+
if left_has_missing is None:
331+
left_has_missing = any(left_indexer == -1)
332+
333+
if left_has_missing:
334+
take_right = self.right_join_keys[i]
335+
336+
if result[name].dtype != self.left[name].dtype:
337+
take_left = self.left[name].values
338+
339+
elif name in self.right:
340+
341+
if right_has_missing is None:
342+
right_has_missing = any(right_indexer == -1)
343+
344+
if right_has_missing:
345+
take_left = self.left_join_keys[i]
346+
347+
if result[name].dtype != self.right[name].dtype:
348+
take_right = self.right[name].values
349+
350+
>>>>>>> e79b978... Preserve dtype in merge keys when possible
319351
elif left_indexer is not None \
320352
and isinstance(self.left_join_keys[i], np.ndarray):
321353

322-
if name is None:
323-
name = 'key_%d' % i
354+
take_left = self.left_join_keys[i]
355+
take_right = self.right_join_keys[i]
356+
357+
if take_left is not None or take_right is not None:
358+
359+
if take_left is None:
360+
lvals = result[name].values
361+
else:
362+
lfill = take_left.dtype.type()
363+
lvals = com.take_1d(take_left, left_indexer, fill_value=lfill)
364+
365+
if take_right is None:
366+
rvals = result[name].values
367+
else:
368+
rfill = take_right.dtype.type()
369+
rvals = com.take_1d(take_right, right_indexer, fill_value=rfill)
370+
371+
key_col = np.where(left_indexer != -1, lvals, rvals)
372+
373+
if name in result:
374+
if result[name].dtype != key_col.dtype:
375+
consolidate = True
376+
result[name] = key_col
377+
else:
378+
result.insert(i, name or 'key_%d' % i, key_col)
379+
consolidate = True
324380

381+
<<<<<<< HEAD
325382
# a faster way?
326383
key_col = algos.take_1d(self.left_join_keys[i], left_indexer)
327384
na_indexer = (left_indexer == -1).nonzero()[0]
328385
right_na_indexer = right_indexer.take(na_indexer)
329386
key_col.put(na_indexer, algos.take_1d(self.right_join_keys[i],
330387
right_na_indexer))
331388
result.insert(i, name, key_col)
389+
=======
390+
if consolidate:
391+
result.consolidate(inplace=True)
392+
>>>>>>> e79b978... Preserve dtype in merge keys when possible
332393

333394
def _get_join_info(self):
334395
left_ax = self.left._data.axes[self.axis]

pandas/tools/tests/test_merge.py

+24 -2
Original file line number | Diff line number | Diff line change
@@ -507,8 +507,8 @@ def test_join_many_non_unique_index(self):
507507

508508
result = result.reset_index()
509509

510-
result['a'] = result['a'].astype(np.float64)
511-
result['b'] = result['b'].astype(np.float64)
510+
# result['a'] = result['a'].astype(np.float64)
511+
# result['b'] = result['b'].astype(np.float64)
512512

513513
assert_frame_equal(result, expected.ix[:, result.columns])
514514

@@ -1033,6 +1033,7 @@ def test_overlapping_columns_error_message(self):
10331033
df2.columns = ['key1', 'foo', 'foo']
10341034
self.assertRaises(ValueError, merge, df, df2)
10351035

1036+
<<<<<<< HEAD
10361037
def test_merge_on_datetime64tz(self):
10371038

10381039
# GH11405
@@ -1426,6 +1427,27 @@ def test_indicator(self):
14261427
test5 = df3.merge(df4, on=['col1', 'col2'],
14271428
how='outer', indicator=True)
14281429
assert_frame_equal(test5, hand_coded_result)
1430+
=======
1431+
def test_merge_join_key_dtype_cast(self):
1432+
# #8596
1433+
1434+
df1 = DataFrame({'key': [1], 'v1': [10]})
1435+
df2 = DataFrame({'key': [2], 'v1': [20]})
1436+
df = merge(df1, df2, how='outer')
1437+
self.assertEqual(df['key'].dtype, 'int64')
1438+
1439+
df1 = DataFrame({'key': [True], 'v1': [1]})
1440+
df2 = DataFrame({'key': [False],'v1': [0]})
1441+
df = merge(df1, df2, how='outer')
1442+
self.assertEqual(df['key'].dtype, 'bool')
1443+
1444+
df1 = DataFrame({'val': [1]})
1445+
df2 = DataFrame({'val': [2]})
1446+
lkey = np.array([1])
1447+
rkey = np.array([2])
1448+
df = merge(df1, df2, left_on=lkey, right_on=rkey, how='outer')
1449+
self.assertEqual(df['key_0'].dtype, 'int64')
1450+
>>>>>>> e79b978... Preserve dtype in merge keys when possible
14291451

14301452

14311453
def _check_merge(x, y):

0 commit comments

Comments
 (0)