Skip to content

Commit e42ecc2

Browse files
committed
ENH: left/right merge operations working and fairly fast, #249
1 parent fc4ca8d commit e42ecc2

File tree

2 files changed

+159
-23
lines changed

2 files changed

+159
-23
lines changed

pandas/tools/merge.py

+37-21
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pandas.core.frame import DataFrame
88
from pandas.core.index import Index
99
from pandas.core.internals import _JoinOperation
10+
import pandas.core.common as com
1011

1112
import pandas._tseries as lib
1213
from pandas._sandbox import Factorizer
@@ -108,11 +109,25 @@ def get_result(self):
108109

109110
new_axis = Index(np.arange(len(left_indexer)))
110111

112+
# TODO: more efficiently handle group keys to avoid extra consolidation!
113+
111114
join_op = _JoinOperation(ldata, rdata, new_axis,
112115
left_indexer, right_indexer, axis=1)
113116

114117
result_data = join_op.get_result(copy=self.copy)
115-
return DataFrame(result_data)
118+
result = DataFrame(result_data)
119+
120+
# insert group keys
121+
for i, name in enumerate(join_names):
122+
# a faster way?
123+
key_col = com.take_1d(left_join_keys[i], left_indexer)
124+
na_indexer = (left_indexer == -1).nonzero()[0]
125+
right_na_indexer = right_indexer.take(na_indexer)
126+
key_col.put(na_indexer, com.take_1d(right_join_keys[i],
127+
right_na_indexer))
128+
result.insert(i, name, key_col)
129+
130+
return result
116131

117132
def _get_merge_data(self, join_names):
118133
"""
@@ -148,8 +163,8 @@ def _get_merge_keys(self):
148163
right_keys = []
149164
join_names = []
150165

151-
need_set_names = False
152-
pop_right = False
166+
# need_set_names = False
167+
# pop_right = False
153168

154169
if (self.on is None and self.left_on is None
155170
and self.right_on is None):
@@ -158,7 +173,8 @@ def _get_merge_keys(self):
158173
left_keys.append(self.left.index.values)
159174
right_keys.append(self.right.index.values)
160175

161-
need_set_names = True
176+
# need_set_names = True
177+
162178
# XXX something better than this
163179
join_names.append('join_key')
164180
elif self.left_index:
@@ -173,30 +189,30 @@ def _get_merge_keys(self):
173189
# use the common columns
174190
common_cols = self.left.columns.intersection(self.right.columns)
175191
self.left_on = self.right_on = common_cols
176-
pop_right = True
192+
193+
# pop_right = True
194+
177195
elif self.on is not None:
178196
if self.left_on is not None or self.right_on is not None:
179197
raise Exception('Can only pass on OR left_on and '
180198
'right_on')
181199
self.left_on = self.right_on = self.on
182-
pop_right = True
183200

184-
if self.right_on is not None:
185-
# this is a touch kludgy, but accomplishes the goal
186-
if pop_right:
187-
right = self.right.copy()
188-
right_keys.extend([right.pop(k) for k in self.right_on])
189-
self.right = right
190-
else:
191-
right_keys.extend([right[k] for k in self.right_on])
201+
# pop_right = True
192202

193-
if need_set_names:
194-
self.left = self.left.copy()
195-
for i, (lkey, name) in enumerate(zip(left_keys, join_names)):
196-
self.left.insert(i, name, lkey)
203+
# this is a touch kludgy, but accomplishes the goal
204+
if self.right_on is not None:
205+
right = self.right.copy()
206+
right_keys.extend([right.pop(k) for k in self.right_on])
207+
self.right = right
197208

198209
if self.left_on is not None:
199-
left_keys.extend([self.left[k] for k in self.left_on])
210+
left = self.left.copy()
211+
left_keys.extend([left.pop(k) for k in self.left_on])
212+
self.left = left
213+
214+
# TODO: something else?
215+
join_names = self.left_on
200216

201217
return left_keys, right_keys, join_names
202218

@@ -253,8 +269,8 @@ def _maybe_make_list(obj):
253269
return [obj]
254270
return obj
255271

256-
def _right_outer_join(x, y, max_groups):
    """Right outer join of key arrays `x` and `y`.

    Implemented by running the Cython left outer join with the operands
    swapped, then un-swapping the returned indexer pair so the caller
    still receives (left_indexer, right_indexer).
    """
    ridx, lidx = sbx.left_outer_join(y, x, max_groups)
    return lidx, ridx
259275

260276
_join_functions = {

pandas/tools/tests/test_merge.py

+122-2
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,44 @@
22
import unittest
33

44
import numpy as np
5+
import random
56

7+
from pandas import *
68
from pandas.tools.merge import merge
79
import pandas._sandbox as sbx
810

911
a_ = np.array
1012

13+
N = 100
NGROUPS = 8

def get_test_data(ngroups=NGROUPS, n=N):
    """Return a shuffled object ndarray of length `n` cycling through the
    integer group labels 0..ngroups-1.

    Parameters
    ----------
    ngroups : int, number of distinct group labels
    n : int, total length of the returned array

    Notes
    -----
    Uses floor division so the `np.tile` repetition count stays integral
    under true division, and materializes ``range`` as a list so the
    padding concatenation works when `ngroups` does not evenly divide `n`.
    """
    unique_groups = list(range(ngroups))
    arr = np.asarray(np.tile(unique_groups, n // ngroups), dtype=object)

    if len(arr) < n:
        # pad with the beginning of one more cycle of labels
        arr = np.asarray(list(arr) + unique_groups[:n - len(arr)],
                         dtype=object)

    random.shuffle(arr)
    return arr
26+
1127
class TestMerge(unittest.TestCase):
1228

1329
def setUp(self):
    """Build two random frames sharing 'key1'/'key2' group columns."""
    # aggregate multiple columns
    self.df = DataFrame({'key1': get_test_data(),
                         'key2': get_test_data(),
                         'data1': np.random.randn(N),
                         'data2': np.random.randn(N)})

    # exclude a couple keys for fun
    self.df = self.df[self.df['key2'] > 1]

    n_small = N // 5
    self.df2 = DataFrame({'key1': get_test_data(n=n_small),
                          'key2': get_test_data(ngroups=NGROUPS // 2,
                                                n=n_small),
                          'value': np.random.randn(n_small)})
1543

1644
def test_cython_left_outer_join(self):
1745
left = a_([0, 1, 2, 1, 2, 0, 0, 1, 2, 3, 3], dtype='i4')
@@ -92,7 +120,31 @@ def test_cython_inner_join(self):
92120
def test_cython_full_outer_join(self):
93121
pass
94122

95-
def test_left_outer_join(self):
    """Left joins on a single key and on all common keys."""
    on_key2 = merge(self.df, self.df2, on='key2')
    _check_join(self.df, self.df2, on_key2, ['key2'], how='left')

    on_both = merge(self.df, self.df2)
    _check_join(self.df, self.df2, on_both, ['key1', 'key2'],
                how='left')
130+
131+
def test_right_outer_join(self):
    """Right joins on a single key and on all common keys."""
    on_key2 = merge(self.df, self.df2, on='key2', how='right')
    _check_join(self.df, self.df2, on_key2, ['key2'], how='right')

    on_both = merge(self.df, self.df2, how='right')
    _check_join(self.df, self.df2, on_both, ['key1', 'key2'],
                how='right')
138+
139+
# def test_full_outer_join(self):
140+
# joined_key2 = merge(self.df, self.df2, on='key2', how='outer')
141+
# _check_join(self.df, self.df2, joined_key2, ['key2'], how='outer')
142+
143+
# joined_both = merge(self.df, self.df2, how='outer')
144+
# _check_join(self.df, self.df2, joined_both, ['key1', 'key2'],
145+
# how='outer')
146+
147+
def test_handle_overlap(self):
96148
pass
97149

98150
def test_merge_common(self):
@@ -101,6 +153,74 @@ def test_merge_common(self):
101153
def test_merge_index(self):
102154
pass
103155

156+
def _check_join(left, right, result, join_col, how='left',
                lsuffix='.x', rsuffix='.y'):
    """Validate a merge result group-by-group against its source frames.

    Every key group of `result` must either reproduce the matching source
    group exactly, or — when the key is absent from one side in an outer
    join — be entirely NA in that side's columns.
    """

    # smoke test: join keys themselves must never be null in the result
    for col in join_col:
        assert(result[col].notnull().all())

    grouped_left = left.groupby(join_col)
    grouped_right = right.groupby(join_col)

    for key, chunk in result.groupby(join_col):
        from_left = _restrict_to_columns(chunk, left.columns, lsuffix)
        from_right = _restrict_to_columns(chunk, right.columns, rsuffix)

        # check each side: compare against the source group when the key
        # exists there, otherwise require all-NA fill values
        for side, grouped, joined, columns in [
                ('left', grouped_left, from_left, left.columns),
                ('right', grouped_right, from_right, right.columns)]:
            try:
                source_group = grouped.get_group(key)
            except KeyError:
                if how == side:
                    # a '<side>' join may only contain keys from that side
                    raise AssertionError('key %s should not have been in the join'
                                         % str(key))

                _assert_all_na(joined, columns, join_col)
            else:
                _assert_same_contents(joined, source_group)
191+
192+
193+
def _restrict_to_columns(group, columns, suffix):
194+
found = [c for c in group.columns
195+
if c in columns or c.replace(suffix, '') in columns]
196+
197+
# filter
198+
group = group.ix[:, found]
199+
200+
# get rid of suffixes, if any
201+
group = group.rename(columns=lambda x: x.replace(suffix, ''))
202+
203+
# put in the right order...
204+
group = group.ix[:, columns]
205+
206+
return group
207+
208+
def _assert_same_contents(join_chunk, source):
209+
NA_SENTINEL = -1234567 # drop_duplicates not so NA-friendly...
210+
211+
jvalues = join_chunk.fillna(NA_SENTINEL).drop_duplicates().values
212+
svalues = source.fillna(NA_SENTINEL).drop_duplicates().values
213+
214+
rows = set(tuple(row) for row in jvalues)
215+
assert(len(rows) == len(source))
216+
assert(all(tuple(row) in rows for row in svalues))
217+
218+
def _assert_all_na(join_chunk, source_columns, join_col):
219+
for c in source_columns:
220+
if c in join_col:
221+
continue
222+
assert(join_chunk[c].isnull().all())
223+
104224
if __name__ == '__main__':
105225
import nose
106226
nose.runmodule(argv=[__file__,'-vvs','-x','--pdb', '--pdb-failure'],

0 commit comments

Comments
 (0)