From eee555ac49bd69698c8b91c94ea494f76bc5998f Mon Sep 17 00:00:00 2001 From: jreback Date: Fri, 14 Feb 2014 11:38:40 -0500 Subject: [PATCH] ENH/BUG: allow single versus multi-index joining on inferred level (GH3662) --- pandas/core/index.py | 44 +++++++++++++++++++++-- pandas/tools/tests/test_merge.py | 61 +++++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/pandas/core/index.py b/pandas/core/index.py index 316e82c05ef30..9fea5ede8dcc7 100644 --- a/pandas/core/index.py +++ b/pandas/core/index.py @@ -1265,8 +1265,48 @@ def join(self, other, how='left', level=None, return_indexers=False): ------- join_index, (left_indexer, right_indexer) """ - if (level is not None and (isinstance(self, MultiIndex) or - isinstance(other, MultiIndex))): + self_is_mi = isinstance(self, MultiIndex) + other_is_mi = isinstance(other, MultiIndex) + + # try to figure out the join level + # GH3662 + if (level is None and (self_is_mi or other_is_mi)): + + # have the same levels/names so a simple join + if self.names == other.names: + pass + + else: + + # figure out join names + self_names = [ n for n in self.names if n is not None ] + other_names = [ n for n in other.names if n is not None ] + overlap = list(set(self_names) & set(other_names)) + + # need at least 1 in common + if not len(overlap): + raise ValueError("cannot join with no level specified and no overlapping names") + + if self_is_mi and other_is_mi: + raise ValueError("cannot join between multiple multi-indexes") + + # make the indices into mi's that match + if self_is_mi: + level = self.names.index(overlap[0]) + result = other._join_level(self, level, how=how, + return_indexers=return_indexers) + + # reversed the results (as we reversed the inputs) + if isinstance(result, tuple): + return result[0], result[2], result[1] + + else: + level = other.names.index(overlap[0]) + return self._join_level(other, level, how=how, + return_indexers=return_indexers) + + # join on the level + if (level is not None and (self_is_mi or other_is_mi)): return self._join_level(other, level, how=how, return_indexers=return_indexers) diff --git a/pandas/tools/tests/test_merge.py b/pandas/tools/tests/test_merge.py index bfa6fd77ba733..28793b7e849f9 100644 --- a/pandas/tools/tests/test_merge.py +++ b/pandas/tools/tests/test_merge.py @@ -8,7 +8,7 @@ import numpy as np import random -from pandas.compat import range, lrange, lzip, zip +from pandas.compat import range, lrange, lzip, zip, StringIO from pandas import compat, _np_version_under1p7 from pandas.tseries.index import DatetimeIndex from pandas.tools.merge import merge, concat, ordered_merge, MergeError @@ -1025,6 +1025,65 @@ def test_int64_overflow_issues(self): result = merge(df1, df2, how='outer') self.assertTrue(len(result) == 2000) + def test_join_multi_levels(self): + + # GH 3662 + # merge multi-levels + from pandas import read_table + household = read_table( + StringIO( +"""household_id,male,wealth +1,0,196087.3 +2,1,316478.7 +3,0,294750 +""" + ), + sep=',', index_col='household_id' + ) + + portfolio = read_table( + StringIO( +""""household_id","asset_id","name","share" +"1","nl0000301109","ABN Amro","1.0" +"2","nl0000289783","Robeco","0.4" +"2","gb00b03mlx29","Royal Dutch Shell","0.6" +"3","gb00b03mlx29","Royal Dutch Shell","0.15" +"3","lu0197800237","AAB Eastern Europe Equity Fund","0.6" +"3","nl0000289965","Postbank BioTech Fonds","0.25" +"4",,,"1.0" +""" + ), + sep=',', index_col=['household_id', 'asset_id'] + ) + + result = household.join(portfolio, how='inner') + expected = DataFrame(dict(male = [0,1,1,0,0,0], + wealth = [ 196087.3, 316478.7, 316478.7, 294750.0, 294750.0, 294750.0 ], + name = ['ABN Amro','Robeco','Royal Dutch Shell','Royal Dutch Shell','AAB Eastern Europe Equity Fund','Postbank BioTech Fonds'], + share = [1.00,0.40,0.60,0.15,0.60,0.25], + household_id = [1,2,2,3,3,3], + asset_id = ['nl0000301109','nl0000289783','gb00b03mlx29','gb00b03mlx29','lu0197800237','nl0000289965']), + ).set_index(['household_id','asset_id']).reindex(columns=['male','wealth','name','share']) + assert_frame_equal(result,expected) + + result = household.join(portfolio, how='outer') + expected = concat([expected,DataFrame(dict(share = [1.00]), + index=MultiIndex.from_tuples([(4,np.nan)], + names=['household_id','asset_id']))], + axis=0).reindex(columns=expected.columns) + assert_frame_equal(result,expected) + + # invalid cases + household.index.name = 'foo' + def f(): + household.join(portfolio, how='inner') + self.assertRaises(ValueError, f) + + portfolio2 = portfolio.copy() + portfolio2.index.set_names(['household_id','foo']) + def f(): + portfolio2.join(portfolio, how='inner') + def _check_join(left, right, result, join_col, how='left', lsuffix='_x', rsuffix='_y'):