Skip to content

Commit 41ec919

Browse files
committed
ENH: hack toward #629
1 parent ac26c84 commit 41ec919

File tree

4 files changed

+109
-22
lines changed

4 files changed

+109
-22
lines changed

pandas/core/frame.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1395,7 +1395,7 @@ def pop(self, item):
13951395
def _series(self):
13961396
return self._data.get_series_dict()
13971397

1398-
def xs(self, key, axis=0, copy=True):
1398+
def xs(self, key, axis=0, level=None, copy=True):
13991399
"""
14001400
Returns a cross-section (row or column) from the DataFrame as a Series
14011401
object. Defaults to returning a row (axis 0)
@@ -1413,6 +1413,15 @@ def xs(self, key, axis=0, copy=True):
14131413
-------
14141414
xs : Series
14151415
"""
1416+
labels = self._get_axis(axis)
1417+
if level is not None:
1418+
indexer = [slice(None, None)] * 2
1419+
indexer[axis] = labels.get_loc_level(key, level=level)
1420+
result = self.ix[tuple(indexer)]
1421+
new_ax = result._get_axis(axis).droplevel(level)
1422+
setattr(result, result._get_axis_name(axis), new_ax)
1423+
return result
1424+
14161425
if axis == 1:
14171426
data = self[key]
14181427
if copy:

pandas/core/index.py

+58-9
Original file line numberDiff line numberDiff line change
@@ -1608,22 +1608,71 @@ def get_loc(self, key):
16081608
if len(key) == self.nlevels:
16091609
return self._engine.get_loc(key)
16101610
else:
1611+
# partial selection
16111612
result = slice(*self.slice_locs(key, key))
16121613
if result.start == result.stop:
16131614
raise KeyError(key)
16141615
return result
16151616
else:
1616-
level = self.levels[0]
1617-
labels = self.labels[0]
1618-
loc = level.get_loc(key)
1617+
return self._get_level_indexer(key, level=0)
16191618

1620-
if self.lexsort_depth == 0:
1621-
return labels == loc
1619+
def get_loc_level(self, key, level=0):
1620+
"""
1621+
Get integer location slice for requested label or tuple
1622+
1623+
Parameters
1624+
----------
1625+
key : label or tuple
1626+
1627+
Returns
1628+
-------
1629+
loc : int or slice object
1630+
"""
1631+
if isinstance(key, tuple) and level == 0:
1632+
if not any(isinstance(k, slice) for k in key):
1633+
if len(key) == self.nlevels:
1634+
return self._engine.get_loc(key)
1635+
else:
1636+
# partial selection
1637+
result = slice(*self.slice_locs(key, key))
1638+
if result.start == result.stop:
1639+
raise KeyError(key)
1640+
return result
16221641
else:
1623-
# sorted, so can return slice object -> view
1624-
i = labels.searchsorted(loc, side='left')
1625-
j = labels.searchsorted(loc, side='right')
1626-
return slice(i, j)
1642+
indexer = None
1643+
for i, k in enumerate(key):
1644+
if k is None:
1645+
continue
1646+
1647+
if isinstance(k, slice):
1648+
if k == slice(None, None):
1649+
continue
1650+
else:
1651+
k_index = np.empty(len(self), dtype=bool)
1652+
k_index[k] = True
1653+
else:
1654+
k_index = self._get_level_indexer(k, level=i)
1655+
1656+
if indexer is None:
1657+
indexer = k_index
1658+
else:
1659+
indexer &= k_index
1660+
return indexer
1661+
else:
1662+
return self._get_level_indexer(key, level=level)
1663+
1664+
def _get_level_indexer(self, key, level=0):
1665+
level_index = self.levels[level]
1666+
loc = level_index.get_loc(key)
1667+
labels = self.labels[level]
1668+
1669+
if level > 0 or self.lexsort_depth == 0:
1670+
return labels == loc
1671+
else:
1672+
# sorted, so can return slice object -> view
1673+
i = labels.searchsorted(loc, side='left')
1674+
j = labels.searchsorted(loc, side='right')
1675+
return slice(i, j)
16271676

16281677
def truncate(self, before=None, after=None):
16291678
"""

pandas/core/series.py

+28-12
Original file line numberDiff line numberDiff line change
@@ -308,12 +308,8 @@ def _get_with(self, key):
308308
indexer = self.ix._convert_to_indexer(key, axis=0)
309309
return self._get_values(indexer)
310310
else:
311-
# mpl hackaround
312311
if isinstance(key, tuple):
313-
try:
314-
return self._get_values(key)
315-
except Exception:
316-
pass
312+
return self._get_values_tuple(key)
317313

318314
if not isinstance(key, (list, np.ndarray)):
319315
key = list(key)
@@ -338,6 +334,33 @@ def _get_with(self, key):
338334
return self._get_values(key)
339335
raise
340336

337+
def _get_values_tuple(self, key):
338+
# mpl hackaround
339+
if any(k is None for k in key):
340+
return self._get_values(key)
341+
342+
if not isinstance(self.index, MultiIndex):
343+
raise ValueError('Can only tuple-index with a MultiIndex')
344+
345+
indexer = self.index.get_loc_level(key)
346+
result = self._get_values(indexer)
347+
348+
# kludgearound
349+
new_index = result.index
350+
for i, k in reversed(list(enumerate(key))):
351+
if k != slice(None, None):
352+
new_index = new_index.droplevel(i)
353+
result.index = new_index
354+
355+
return result
356+
357+
def _get_values(self, indexer):
358+
try:
359+
return Series(self.values[indexer], index=self.index[indexer],
360+
name=self.name)
361+
except Exception:
362+
return self.values[indexer]
363+
341364
def __setitem__(self, key, value):
342365
values = self.values
343366
try:
@@ -397,13 +420,6 @@ def _set_labels(self, key, value):
397420
% str(key[mask]))
398421
self._set_values(indexer, value)
399422

400-
def _get_values(self, indexer):
401-
try:
402-
return Series(self.values[indexer], index=self.index[indexer],
403-
name=self.name)
404-
except Exception:
405-
return self.values[indexer]
406-
407423
def _set_values(self, key, value):
408424
self.values[key] = value
409425

pandas/tests/test_multilevel.py

+13
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,19 @@ def test_xs_partial(self):
203203
assert_frame_equal(result, expected)
204204
assert_frame_equal(result, result2)
205205

206+
def test_xs_level(self):
207+
result = self.frame.xs('two', level=1)
208+
expected = self.frame[self.frame.index.get_level_values(1) == 'two']
209+
expected.index = expected.index.droplevel(1)
210+
211+
assert_frame_equal(result, expected)
212+
213+
def test_xs_level_series(self):
214+
s = self.frame['A']
215+
result = s[:, 'two']
216+
expected = self.frame.xs('two', level=1)['A']
217+
assert_series_equal(result, expected)
218+
206219
def test_fancy_2d(self):
207220
result = self.frame.ix['foo', 'B']
208221
expected = self.frame.xs('foo')['B']

0 commit comments

Comments
 (0)