Skip to content

Commit 8a9fe43

Browse files
committed
BUG: reimplement MultiIndex.remove_unused_levels
* Add a large random test case for remove_unused_levels that failed the previous implementation * Fix #16556, a performance issue with the previous implementation * Always return at least a view instead of the original index
1 parent ee8346d commit 8a9fe43

File tree

3 files changed

+45
-20
lines changed

3 files changed

+45
-20
lines changed

doc/source/whatsnew/v0.20.2.txt

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Performance Improvements
3131
- Performance regression fix for MultiIndexes (:issue:`16319`, :issue:`16346`)
3232
- Improved performance of ``.clip()`` with scalar arguments (:issue:`15400`)
3333
- Improved performance of groupby with categorical groupers (:issue:`16413`)
34+
- Improved performance of ``MultiIndex.remove_unused_levels()`` (:issue:`16556`)
3435

3536
.. _whatsnew_0202.bug_fixes:
3637

@@ -62,6 +63,7 @@ Indexing
6263
^^^^^^^^
6364

6465
- Bug in ``DataFrame.reset_index(level=)`` with single level index (:issue:`16263`)
66+
- Bug in ``MultiIndex.remove_unused_levels()`` (:issue:`16556`)
6567

6668

6769
I/O

pandas/core/indexes/multi.py

+15-19
Original file line numberDiff line numberDiff line change
@@ -1290,42 +1290,38 @@ def remove_unused_levels(self):
12901290
new_levels = []
12911291
new_labels = []
12921292

1293-
changed = np.ones(self.nlevels, dtype=bool)
1294-
for i, (lev, lab) in enumerate(zip(self.levels, self.labels)):
1293+
changed = False
1294+
for lev, lab in zip(self.levels, self.labels):
12951295

12961296
uniques = algos.unique(lab)
12971297

12981298
# nothing unused
12991299
if len(uniques) == len(lev):
13001300
new_levels.append(lev)
13011301
new_labels.append(lab)
1302-
changed[i] = False
13031302
continue
13041303

1305-
# set difference, then reverse sort
1306-
diff = Index(np.arange(len(lev))).difference(uniques)
1307-
unused = diff.sort_values(ascending=False)
1304+
changed = True
1305+
1306+
# labels get mapped from uniques to 0:len(uniques)
1307+
label_mapping = np.zeros(len(lev))
1308+
label_mapping[uniques] = np.arange(len(uniques))
1309+
lab = label_mapping[lab]
13081310

13091311
# new levels are simple
13101312
lev = lev.take(uniques)
13111313

1312-
# new labels, we remove the unsued
1313-
# by decrementing the labels for that value
1314-
# prob a better way
1315-
for u in unused:
1316-
1317-
lab = np.where(lab > u, lab - 1, lab)
1318-
13191314
new_levels.append(lev)
13201315
new_labels.append(lab)
13211316

1322-
# nothing changed
1323-
if not changed.any():
1324-
return self
1317+
result = self._shallow_copy()
13251318

1326-
return MultiIndex(new_levels, new_labels,
1327-
names=self.names, sortorder=self.sortorder,
1328-
verify_integrity=False)
1319+
if changed:
1320+
result._reset_identity()
1321+
result._set_levels(new_levels, validate=False)
1322+
result._set_labels(new_labels, validate=False)
1323+
1324+
return result
13291325

13301326
@property
13311327
def nlevels(self):

pandas/tests/indexes/test_multi.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -2489,7 +2489,34 @@ def test_reconstruct_remove_unused(self):
24892489
# idempotent
24902490
result2 = result.remove_unused_levels()
24912491
tm.assert_index_equal(result2, expected)
2492-
assert result2 is result
2492+
assert result2.is_(result)
2493+
2494+
def test_remove_unused_levels_large(self):
2495+
# GH16556
2496+
2497+
def check(first_type=None, second_type=None):
2498+
size = 1 << 16
2499+
first = np.random.randint(0, 1 << 13, size)
2500+
if first_type is not None:
2501+
first = first.astype(first_type)
2502+
second = np.random.randint(0, 1 << 10, size)
2503+
if second_type is not None:
2504+
second = second.astype(second_type)
2505+
third = np.random.rand(size)
2506+
df = DataFrame(dict(first=first, second=second, third=third))
2507+
df = df.groupby(['first', 'second']).sum()
2508+
df = df[df.third < 0.1]
2509+
2510+
result = df.index.remove_unused_levels()
2511+
assert len(result.levels[0]) < len(df.index.levels[0])
2512+
assert len(result.levels[1]) < len(df.index.levels[1])
2513+
assert result.equals(df.index)
2514+
2515+
expected = df.reset_index().set_index(['first', 'second']).index
2516+
assert result.equals(expected)
2517+
2518+
check()
2519+
check('datetime64[D]', 'str')
24932520

24942521
def test_isin(self):
24952522
values = [('foo', 2), ('bar', 3), ('quux', 4)]

0 commit comments

Comments
 (0)