Skip to content

Commit fe50bec

Browse files
committed
ENH: indexing object with MultiIndex works if indexing the top level. address GH #120
1 parent 425924e commit fe50bec

File tree

5 files changed

+52
-13
lines changed

5 files changed

+52
-13
lines changed

pandas/core/frame.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -793,12 +793,16 @@ def _slice(self, slobj, axis=0):
793793

794794
def _getitem_multilevel(self, key):
795795
loc = self.columns.get_loc(key)
796-
if isinstance(loc, slice):
796+
if isinstance(loc, (slice, np.ndarray)):
797797
new_columns = self.columns[loc]
798-
new_columns = _maybe_droplevels(new_columns, key)
799-
new_values = self.values[:, loc]
800-
result = DataFrame(new_values, index=self.index,
801-
columns=new_columns)
798+
result_columns = _maybe_droplevels(new_columns, key)
799+
if self._is_mixed_type:
800+
result = self.reindex(columns=new_columns)
801+
result.columns = result_columns
802+
else:
803+
new_values = self.values[:, loc]
804+
result = DataFrame(new_values, index=self.index,
805+
columns=result_columns)
802806
return result
803807
else:
804808
return self._getitem_single(key)

pandas/core/index.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -965,11 +965,13 @@ def get_loc(self, key):
965965
labels = self.labels[0]
966966
loc = level.get_loc(key)
967967

968-
assert(self.lexsort_depth >= 1)
969-
970-
i = labels.searchsorted(loc, side='left')
971-
j = labels.searchsorted(loc, side='right')
972-
return slice(i, j)
968+
if self.lexsort_depth == 0:
969+
return labels == loc
970+
else:
971+
# sorted, so can return slice object -> view
972+
i = labels.searchsorted(loc, side='left')
973+
j = labels.searchsorted(loc, side='right')
974+
return slice(i, j)
973975

974976
def _get_tuple_loc(self, tup):
975977
indexer = self._get_label_key(tup)

pandas/core/internals.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,8 @@ def xs(self, key, axis=1, copy=True):
438438

439439
new_axes = list(self.axes)
440440

441-
if isinstance(loc, slice):
441+
# could be an array indexer!
442+
if isinstance(loc, (slice, np.ndarray)):
442443
new_axes[axis] = new_axes[axis][loc]
443444
else:
444445
new_axes.pop(axis)

pandas/core/series.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def _multilevel_index(self, key):
271271
values = self.values
272272
try:
273273
loc = self.index.get_loc(key)
274-
if isinstance(loc, slice):
274+
if isinstance(loc, (slice, np.ndarray)):
275275
# TODO: what if a level contains tuples??
276276
new_index = self.index[loc]
277277
new_index = _maybe_droplevels(new_index, key)

pandas/tests/test_multilevel.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,39 @@ def test_frame_getitem_view(self):
276276
def test_frame_getitem_not_sorted(self):
277277
df = self.frame.T
278278
df['foo', 'four'] = 'foo'
279-
self.assertRaises(Exception, df.__getitem__, 'foo')
279+
280+
arrays = [np.array(x) for x in zip(*df.columns.get_tuple_index())]
281+
282+
result = df['foo']
283+
result2 = df.ix[:, 'foo']
284+
expected = df.reindex(columns=df.columns[arrays[0] == 'foo'])
285+
expected.columns = expected.columns.droplevel(0)
286+
assert_frame_equal(result, expected)
287+
assert_frame_equal(result2, expected)
288+
289+
df = df.T
290+
result = df.xs('foo')
291+
result2 = df.ix['foo']
292+
expected = df.reindex(df.index[arrays[0] == 'foo'])
293+
expected.index = expected.index.droplevel(0)
294+
assert_frame_equal(result, expected)
295+
assert_frame_equal(result2, expected)
296+
297+
def test_series_getitem_not_sorted(self):
298+
arrays = [['bar', 'bar', 'baz', 'baz', 'qux', 'qux', 'foo', 'foo'],
299+
['one', 'two', 'one', 'two', 'one', 'two', 'one', 'two']]
300+
tuples = zip(*arrays)
301+
index = MultiIndex.from_tuples(tuples)
302+
s = Series(randn(8), index=index)
303+
304+
arrays = [np.array(x) for x in zip(*index.get_tuple_index())]
305+
306+
result = s['qux']
307+
result2 = s.ix['qux']
308+
expected = s[arrays[0] == 'qux']
309+
expected.index = expected.index.droplevel(0)
310+
assert_series_equal(result, expected)
311+
assert_series_equal(result2, expected)
280312

281313
if __name__ == '__main__':
282314

0 commit comments

Comments
 (0)