Skip to content

Commit 9852c6f

Browse files
committed
BUG: reimplement MultiIndex.remove_unused_levels
* Add a large random test case for remove_unused_levels that failed the previous implementation
* Fix #16556, a performance issue with the previous implementation
* Always return at least a view instead of the original index
1 parent fb47ee5 commit 9852c6f

File tree

3 files changed

+45
-20
lines changed

3 files changed

+45
-20
lines changed

doc/source/whatsnew/v0.20.2.txt

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Performance Improvements
3232
- Performance regression fix for MultiIndexes (:issue:`16319`, :issue:`16346`)
3333
- Improved performance of ``.clip()`` with scalar arguments (:issue:`15400`)
3434
- Improved performance of groupby with categorical groupers (:issue:`16413`)
35+
- Improved performance of ``MultiIndex.remove_unused_levels()`` (:issue:`16556`)
3536

3637
.. _whatsnew_0202.bug_fixes:
3738

@@ -61,6 +62,7 @@ Indexing
6162

6263
- Bug in ``DataFrame.reset_index(level=)`` with single level index (:issue:`16263`)
6364
- Bug in partial string indexing with a monotonic, but not strictly-monotonic, index incorrectly reversing the slice bounds (:issue:`16515`)
65+
- Bug in ``MultiIndex.remove_unused_levels()`` that could return a ``MultiIndex`` not equal to the original (:issue:`16556`)
6466

6567
I/O
6668
^^^

pandas/core/indexes/multi.py

+15-19
Original file line numberDiff line numberDiff line change
@@ -1290,42 +1290,38 @@ def remove_unused_levels(self):
12901290
new_levels = []
12911291
new_labels = []
12921292

1293-
changed = np.ones(self.nlevels, dtype=bool)
1294-
for i, (lev, lab) in enumerate(zip(self.levels, self.labels)):
1293+
changed = False
1294+
for lev, lab in zip(self.levels, self.labels):
12951295

12961296
uniques = algos.unique(lab)
12971297

12981298
# nothing unused
12991299
if len(uniques) == len(lev):
13001300
new_levels.append(lev)
13011301
new_labels.append(lab)
1302-
changed[i] = False
13031302
continue
13041303

1305-
# set difference, then reverse sort
1306-
diff = Index(np.arange(len(lev))).difference(uniques)
1307-
unused = diff.sort_values(ascending=False)
1304+
changed = True
1305+
1306+
# labels get mapped from uniques to 0:len(uniques)
1307+
label_mapping = np.zeros(len(lev))
1308+
label_mapping[uniques] = np.arange(len(uniques))
1309+
lab = label_mapping[lab]
13081310

13091311
# new levels are simple
13101312
lev = lev.take(uniques)
13111313

1312-
# new labels, we remove the unsued
1313-
# by decrementing the labels for that value
1314-
# prob a better way
1315-
for u in unused:
1316-
1317-
lab = np.where(lab > u, lab - 1, lab)
1318-
13191314
new_levels.append(lev)
13201315
new_labels.append(lab)
13211316

1322-
# nothing changed
1323-
if not changed.any():
1324-
return self
1317+
result = self._shallow_copy()
13251318

1326-
return MultiIndex(new_levels, new_labels,
1327-
names=self.names, sortorder=self.sortorder,
1328-
verify_integrity=False)
1319+
if changed:
1320+
result._reset_identity()
1321+
result._set_levels(new_levels, validate=False)
1322+
result._set_labels(new_labels, validate=False)
1323+
1324+
return result
13291325

13301326
@property
13311327
def nlevels(self):

pandas/tests/indexes/test_multi.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -2515,7 +2515,34 @@ def test_reconstruct_remove_unused(self):
25152515
# idempotent
25162516
result2 = result.remove_unused_levels()
25172517
tm.assert_index_equal(result2, expected)
2518-
assert result2 is result
2518+
assert result2.is_(result)
2519+
2520+
@pytest.mark.parametrize('first_type,second_type', [
2521+
('int64', 'int64'),
2522+
('datetime64[D]', 'str')])
2523+
def test_remove_unused_levels_large(self, first_type, second_type):
2524+
# GH16556
2525+
2526+
# because tests should be deterministic (and this test in particular
2527+
# checks that levels are removed, which is not the case for every
2528+
# random input):
2529+
rng = np.random.RandomState(4) # seed is arbitrary value that works
2530+
2531+
size = 1 << 16
2532+
df = DataFrame(dict(
2533+
first=rng.randint(0, 1 << 13, size).astype(first_type),
2534+
second=rng.randint(0, 1 << 10, size).astype(second_type),
2535+
third=rng.rand(size)))
2536+
df = df.groupby(['first', 'second']).sum()
2537+
df = df[df.third < 0.1]
2538+
2539+
result = df.index.remove_unused_levels()
2540+
assert len(result.levels[0]) < len(df.index.levels[0])
2541+
assert len(result.levels[1]) < len(df.index.levels[1])
2542+
assert result.equals(df.index)
2543+
2544+
expected = df.reset_index().set_index(['first', 'second']).index
2545+
tm.assert_index_equal(result, expected)
25192546

25202547
def test_isin(self):
25212548
values = [('foo', 2), ('bar', 3), ('quux', 4)]

0 commit comments

Comments
 (0)