2
2
import unittest
3
3
4
4
import numpy as np
5
+ import random
5
6
7
+ from pandas import *
6
8
from pandas .tools .merge import merge
7
9
import pandas ._sandbox as sbx
8
10
9
11
# Terse alias for np.array, used to build the expected-index arrays
# in the cython join tests below.
a_ = np.array
10
12
13
N = 100
NGROUPS = 8


def get_test_data(ngroups=NGROUPS, n=N):
    """Return an object-dtype array of ``n`` group labels.

    Labels 0..ngroups-1 are tiled to length ``n`` (padding with a prefix of
    the label sequence when ``ngroups`` does not divide ``n`` evenly), then
    shuffled in place so key order is random.

    Parameters
    ----------
    ngroups : int, number of distinct group labels
    n : int, total number of labels to produce
    """
    # materialize as a list: range objects cannot be concatenated below
    unique_groups = list(range(ngroups))

    # floor division, consistent with the N // 5 usage in setUp: np.tile
    # requires an integer repeat count (true division yields a float)
    arr = np.asarray(np.tile(unique_groups, n // ngroups), dtype=object)

    if len(arr) < n:
        # pad with the first few labels so the result has exactly n elements
        arr = np.asarray(list(arr) + unique_groups[:n - len(arr)],
                         dtype=object)

    random.shuffle(arr)
    return arr
26
+
11
27
class TestMerge (unittest .TestCase ):
12
28
13
29
def setUp (self ):
14
- pass
30
+ # aggregate multiple columns
31
+ self .df = DataFrame ({'key1' : get_test_data (),
32
+ 'key2' : get_test_data (),
33
+ 'data1' : np .random .randn (N ),
34
+ 'data2' : np .random .randn (N )})
35
+
36
+ # exclude a couple keys for fun
37
+ self .df = self .df [self .df ['key2' ] > 1 ]
38
+
39
+ self .df2 = DataFrame ({'key1' : get_test_data (n = N // 5 ),
40
+ 'key2' : get_test_data (ngroups = NGROUPS // 2 ,
41
+ n = N // 5 ),
42
+ 'value' : np .random .randn (N // 5 )})
15
43
16
44
def test_cython_left_outer_join (self ):
17
45
left = a_ ([0 , 1 , 2 , 1 , 2 , 0 , 0 , 1 , 2 , 3 , 3 ], dtype = 'i4' )
@@ -92,7 +120,31 @@ def test_cython_inner_join(self):
92
120
    def test_cython_full_outer_join(self):
        # placeholder: full outer join via the cython routine not exercised yet
        pass
94
122
95
- def test_left_join (self ):
123
+ def test_left_outer_join (self ):
124
+ joined_key2 = merge (self .df , self .df2 , on = 'key2' )
125
+ _check_join (self .df , self .df2 , joined_key2 , ['key2' ], how = 'left' )
126
+
127
+ joined_both = merge (self .df , self .df2 )
128
+ _check_join (self .df , self .df2 , joined_both , ['key1' , 'key2' ],
129
+ how = 'left' )
130
+
131
+ def test_right_outer_join (self ):
132
+ joined_key2 = merge (self .df , self .df2 , on = 'key2' , how = 'right' )
133
+ _check_join (self .df , self .df2 , joined_key2 , ['key2' ], how = 'right' )
134
+
135
+ joined_both = merge (self .df , self .df2 , how = 'right' )
136
+ _check_join (self .df , self .df2 , joined_both , ['key1' , 'key2' ],
137
+ how = 'right' )
138
+
139
+ # def test_full_outer_join(self):
140
+ # joined_key2 = merge(self.df, self.df2, on='key2', how='outer')
141
+ # _check_join(self.df, self.df2, joined_key2, ['key2'], how='outer')
142
+
143
+ # joined_both = merge(self.df, self.df2, how='outer')
144
+ # _check_join(self.df, self.df2, joined_both, ['key1', 'key2'],
145
+ # how='outer')
146
+
147
    def test_handle_overlap(self):
        # placeholder: presumably meant to cover overlapping (suffixed)
        # column names -- not implemented yet
        pass
97
149
98
150
def test_merge_common (self ):
@@ -101,6 +153,74 @@ def test_merge_common(self):
101
153
    def test_merge_index(self):
        # placeholder: presumably meant to cover index-based merging --
        # not implemented yet
        pass
103
155
156
def _check_join(left, right, result, join_col, how='left',
                lsuffix='.x', rsuffix='.y'):
    """Brute-force validation of a merge result.

    Groups ``result`` by the join keys and checks every group against the
    corresponding groups of ``left`` and ``right``: a key present on a side
    must reproduce that side's rows exactly; a key absent from a side must
    yield all-NA values in that side's columns. With ``how='left'``
    (resp. ``'right'``) every result key must exist in ``left``
    (resp. ``right``).
    """
    # some smoke tests: join key columns must never be NA in the result
    for c in join_col:
        assert(result[c].notnull().all())

    left_grouped = left.groupby(join_col)
    right_grouped = right.groupby(join_col)

    for group_key, group in result.groupby(join_col):
        # split the result group back into left-side / right-side columns,
        # undoing any merge suffixes
        l_joined = _restrict_to_columns(group, left.columns, lsuffix)
        r_joined = _restrict_to_columns(group, right.columns, rsuffix)

        try:
            lgroup = left_grouped.get_group(group_key)
        except KeyError:
            # a left join may only contain keys that exist in `left`
            if how == 'left':
                raise AssertionError('key %s should not have been in the join'
                                     % str(group_key))

            # key came from the right side only -> left columns must be NA
            _assert_all_na(l_joined, left.columns, join_col)
        else:
            _assert_same_contents(l_joined, lgroup)

        try:
            rgroup = right_grouped.get_group(group_key)
        except KeyError:
            # symmetric check: a right join may only contain keys from `right`
            if how == 'right':
                raise AssertionError('key %s should not have been in the join'
                                     % str(group_key))

            # key came from the left side only -> right columns must be NA
            _assert_all_na(r_joined, right.columns, join_col)
        else:
            _assert_same_contents(r_joined, rgroup)
191
+
192
+
193
def _restrict_to_columns(group, columns, suffix):
    """Slice ``group`` down to the columns that originated from one input
    frame, stripping the merge ``suffix`` and restoring the original
    column order.
    """
    # NOTE(review): str.replace strips the suffix anywhere in the name, not
    # only at the end -- assumed harmless for the '.x'/'.y' suffixes used
    # by the caller; verify if other suffixes are introduced.
    found = [c for c in group.columns
             if c in columns or c.replace(suffix, '') in columns]

    # filter
    group = group.ix[:, found]

    # get rid of suffixes, if any
    group = group.rename(columns=lambda x: x.replace(suffix, ''))

    # put in the right order...
    group = group.ix[:, columns]

    return group
207
+
208
+ def _assert_same_contents (join_chunk , source ):
209
+ NA_SENTINEL = - 1234567 # drop_duplicates not so NA-friendly...
210
+
211
+ jvalues = join_chunk .fillna (NA_SENTINEL ).drop_duplicates ().values
212
+ svalues = source .fillna (NA_SENTINEL ).drop_duplicates ().values
213
+
214
+ rows = set (tuple (row ) for row in jvalues )
215
+ assert (len (rows ) == len (source ))
216
+ assert (all (tuple (row ) in rows for row in svalues ))
217
+
218
+ def _assert_all_na (join_chunk , source_columns , join_col ):
219
+ for c in source_columns :
220
+ if c in join_col :
221
+ continue
222
+ assert (join_chunk [c ].isnull ().all ())
223
+
104
224
if __name__ == '__main__' :
105
225
import nose
106
226
nose .runmodule (argv = [__file__ ,'-vvs' ,'-x' ,'--pdb' , '--pdb-failure' ],
0 commit comments