Skip to content

Commit 5dc1afc

Browse files
committed
ENH: Index.reindex refactoring and getting drop to work with MultiIndex. address GH #101
1 parent be6b004 commit 5dc1afc

File tree

6 files changed

+86
-35
lines changed

6 files changed

+86
-35
lines changed

pandas/core/generic.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -157,17 +157,9 @@ def drop(self, labels, axis=0):
157157
"""
158158
axis_name = self._get_axis_name(axis)
159159
axis = self._get_axis(axis)
160-
161-
labels = np.asarray(list(labels), dtype=object)
162-
163-
indexer, mask = axis.get_indexer(labels)
164-
if not mask.all():
165-
raise ValueError('labels %s not contained in axis' % labels[-mask])
166-
167-
new_axis = np.delete(np.asarray(axis), indexer)
160+
new_axis = axis.drop(labels)
168161
return self.reindex(**{axis_name : new_axis})
169162

170-
171163
class NDFrame(PandasObject):
172164
"""
173165

pandas/core/index.py

+60-11
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,10 @@ def get_indexer(self, target, method=None):
268268
target.indexMap, method)
269269
return indexer, mask
270270

271+
def reindex(self, target, method=None):
272+
indexer, mask = self.get_indexer(target, method=method)
273+
return target, indexer, mask
274+
271275
def slice_locs(self, start=None, end=None):
272276
"""
273277
@@ -300,6 +304,13 @@ def delete(self, loc):
300304
arr = np.delete(np.asarray(self), loc)
301305
return Index(arr)
302306

307+
def drop(self, labels):
308+
labels = np.asarray(list(labels), dtype=object)
309+
indexer, mask = self.get_indexer(labels)
310+
if not mask.all():
311+
raise ValueError('labels %s not contained in axis' % labels[-mask])
312+
return self.delete(indexer)
313+
303314
class DateIndex(Index):
304315
pass
305316

@@ -486,6 +497,25 @@ def take(self, *args, **kwargs):
486497
new_labels = [lab.take(*args, **kwargs) for lab in self.labels]
487498
return MultiIndex(levels=self.levels, labels=new_labels)
488499

500+
def drop(self, labels):
501+
try:
502+
arr = np.asarray(list(labels), dtype=object)
503+
indexer, mask = self.get_indexer(arr)
504+
if not mask.all():
505+
raise ValueError('labels %s not contained in axis' % arr[-mask])
506+
except Exception:
507+
pass
508+
509+
inds = []
510+
for label in labels:
511+
loc = self.get_loc(label)
512+
if isinstance(loc, int):
513+
inds.append(loc)
514+
else:
515+
inds.extend(range(loc.start, loc.stop))
516+
517+
return self.delete(inds)
518+
489519
def droplevel(self, level=0):
490520
"""
491521
Return Index with requested level removed. If MultiIndex has only 2
@@ -556,17 +586,32 @@ def get_indexer(self, target, method=None):
556586
}
557587
method = aliases.get(method, method)
558588

559-
if not isinstance(target, MultiIndex):
560-
raise TypeError('Can only align with other MultiIndex objects')
589+
if isinstance(target, MultiIndex):
590+
target_index = target.get_tuple_index()
591+
else:
592+
if len(target) > 0:
593+
val = target[0]
594+
if not isinstance(val, tuple) or len(val) != self.nlevels:
595+
raise ValueError('can only pass MultiIndex or '
596+
'array of tuples')
561597

562-
self_index = self.get_tuple_index()
563-
target_index = target.get_tuple_index()
598+
target_index = target
564599

600+
self_index = self.get_tuple_index()
565601
indexer, mask = _tseries.getFillVec(self_index, target_index,
566602
self_index.indexMap,
567-
target_index.indexMap, method)
603+
target.indexMap, method)
568604
return indexer, mask
569605

606+
def reindex(self, target, method=None):
607+
indexer, mask = self.get_indexer(target, method=method)
608+
609+
# hopefully?
610+
if not isinstance(target, MultiIndex):
611+
target = MultiIndex.from_tuples(target)
612+
613+
return target, indexer, mask
614+
570615
def get_tuple_index(self):
571616
return Index(list(self))
572617

@@ -699,14 +744,18 @@ def equals(self, other):
699744
if len(self) != len(other):
700745
return False
701746

702-
# if not self.equal_levels(other):
703-
# return False
704-
705747
for i in xrange(self.nlevels):
706-
if not self.levels[i].equals(other.levels[i]):
707-
return False
708-
if not np.array_equal(self.labels[i], other.labels[i]):
748+
svalues = np.asarray(self.levels[i]).take(self.labels[i])
749+
ovalues = np.asarray(other.levels[i]).take(other.labels[i])
750+
751+
if not np.array_equal(svalues, ovalues):
709752
return False
753+
754+
# if not self.levels[i].equals(other.levels[i]):
755+
# return False
756+
# if not np.array_equal(self.labels[i], other.labels[i]):
757+
# return False
758+
710759
return True
711760

712761
def equal_levels(self, other):

pandas/core/internals.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def reindex_items_from(self, new_ref_items):
124124
-------
125125
reindexed : Block
126126
"""
127-
indexer, mask = self.items.get_indexer(new_ref_items)
127+
new_ref_items, indexer, mask = self.items.reindex(new_ref_items)
128128
masked_idx = indexer[mask]
129129
new_values = self.values.take(masked_idx, axis=0)
130130
new_items = self.items.take(masked_idx)
@@ -525,7 +525,10 @@ def _delete_from_block(self, i, item):
525525

526526
def _add_new_block(self, item, value):
527527
# Do we care about dtype at the moment?
528-
new_block = make_block(value, [item], self.items)
528+
529+
# hm, elaborate hack?
530+
loc = self.items.get_loc(item)
531+
new_block = make_block(value, self.items[loc:loc+1], self.items)
529532
self.blocks.append(new_block)
530533

531534
def _find_block(self, item):
@@ -546,7 +549,7 @@ def reindex_axis(self, new_axis, method=None, axis=0):
546549
new_axis = _ensure_index(new_axis)
547550
cur_axis = self.axes[axis]
548551

549-
indexer, mask = cur_axis.get_indexer(new_axis, method)
552+
new_axis, indexer, mask = cur_axis.reindex(new_axis, method)
550553

551554
# TODO: deal with length-0 case? or does it fall out?
552555
notmask = -mask
@@ -572,16 +575,16 @@ def reindex_items(self, new_items):
572575
data = data.consolidate()
573576
return data.reindex_items(new_items)
574577

578+
# TODO: this part could be faster (!)
579+
new_items, _, mask = self.items.reindex(new_items)
580+
notmask = -mask
581+
575582
new_blocks = []
576583
for block in self.blocks:
577584
newb = block.reindex_items_from(new_items)
578585
if len(newb.items) > 0:
579586
new_blocks.append(newb)
580587

581-
# TODO: this part could be faster (!)
582-
_, mask = self.items.get_indexer(new_items)
583-
notmask = -mask
584-
585588
if notmask.any():
586589
extra_items = new_items[notmask]
587590

pandas/core/series.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,7 @@ def reindex(self, index=None, method=None):
10611061
if len(self.index) == 0:
10621062
return Series(nan, index=index)
10631063

1064-
fill_vec, mask = self.index.get_indexer(index, method=method)
1064+
new_index, fill_vec, mask = self.index.reindex(index, method=method)
10651065
new_values = self.values.take(fill_vec)
10661066

10671067
notmask = -mask
@@ -1073,7 +1073,7 @@ def reindex(self, index=None, method=None):
10731073

10741074
np.putmask(new_values, notmask, nan)
10751075

1076-
return Series(new_values, index=index)
1076+
return Series(new_values, index=new_index)
10771077

10781078
def reindex_like(self, other, method=None):
10791079
"""

pandas/tests/test_index.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,13 @@ def test_get_indexer(self):
496496
assert_almost_equal(r2, rbfill2)
497497

498498
# pass non-MultiIndex
499+
r1, r2 = idx1.get_indexer(idx2.get_tuple_index())
500+
rexp1, rexp2 = idx1.get_indexer(idx2)
501+
assert_almost_equal(r1, rexp1)
502+
assert_almost_equal(r2, rexp2)
503+
499504
self.assertRaises(Exception, idx1.get_indexer,
500-
idx2.get_tuple_index())
505+
list(zip(*idx2.get_tuple_index())[0]))
501506

502507
def test_format(self):
503508
self.index.format()

pandas/tests/test_multilevel.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
# pylint: disable-msg=W0612,E1101
2-
from copy import deepcopy
3-
from datetime import datetime, timedelta
42
from cStringIO import StringIO
5-
import cPickle as pickle
63
import operator
7-
import os
84
import unittest
95

106
from numpy import random, nan
@@ -52,6 +48,11 @@ def _test_roundtrip(frame):
5248
_test_roundtrip(self.ymd)
5349
_test_roundtrip(self.ymd.T)
5450

51+
def test_reindex(self):
52+
reindexed = self.frame.ix[[('foo', 'one'), ('bar', 'one')]]
53+
expected = self.frame.ix[[0, 3]]
54+
assert_frame_equal(reindexed, expected)
55+
5556
def test_repr_to_string(self):
5657
repr(self.frame)
5758
repr(self.ymd)
@@ -160,7 +161,8 @@ def test_sortlevel_mixed(self):
160161
dft['foo', 'three'] = 'bar'
161162

162163
sorted_after = dft.sortlevel(1, axis=1)
163-
assert_frame_equal(sorted_before, sorted_after.drop(['foo'], axis=1))
164+
assert_frame_equal(sorted_before.drop([('foo', 'three')], axis=1),
165+
sorted_after.drop([('foo', 'three')], axis=1))
164166

165167
def test_alignment(self):
166168
pass

0 commit comments

Comments
 (0)