Skip to content

Commit 4231e0c

Browse files
committed
BUG: reimplement MultiIndex.remove_unused_levels
* Add a large random test case for remove_unused_levels that failed the previous implementation
* Fix #16556, a performance issue with the previous implementation
* Add inplace functionality
* Always return (if not inplace) at least a view instead of the original index
1 parent ee8346d commit 4231e0c

File tree

2 files changed

+46
-21
lines changed

2 files changed

+46
-21
lines changed

pandas/core/indexes/multi.py

+25-20
Original file line numberDiff line numberDiff line change
@@ -1252,7 +1252,7 @@ def _sort_levels_monotonic(self):
12521252
names=self.names, sortorder=self.sortorder,
12531253
verify_integrity=False)
12541254

1255-
def remove_unused_levels(self):
1255+
def remove_unused_levels(self, inplace=False):
12561256
"""
12571257
create a new MultiIndex from the current that removing
12581258
unused levels, meaning that they are not expressed in the labels
@@ -1263,6 +1263,11 @@ def remove_unused_levels(self):
12631263
12641264
.. versionadded:: 0.20.0
12651265
1266+
Parameters
1267+
----------
1268+
inplace : bool
1269+
if True, mutates in place
1270+
12661271
Returns
12671272
-------
12681273
MultiIndex
@@ -1290,42 +1295,42 @@ def remove_unused_levels(self):
12901295
new_levels = []
12911296
new_labels = []
12921297

1293-
changed = np.ones(self.nlevels, dtype=bool)
1294-
for i, (lev, lab) in enumerate(zip(self.levels, self.labels)):
1298+
changed = False
1299+
for lev, lab in zip(self.levels, self.labels):
12951300

12961301
uniques = algos.unique(lab)
12971302

12981303
# nothing unused
12991304
if len(uniques) == len(lev):
13001305
new_levels.append(lev)
13011306
new_labels.append(lab)
1302-
changed[i] = False
13031307
continue
13041308

1305-
# set difference, then reverse sort
1306-
diff = Index(np.arange(len(lev))).difference(uniques)
1307-
unused = diff.sort_values(ascending=False)
1309+
changed = True
1310+
1311+
# labels get mapped from uniques to 0:len(uniques)
1312+
label_mapping = np.zeros(len(lev))
1313+
label_mapping[uniques] = np.arange(len(uniques))
1314+
lab = label_mapping[lab]
13081315

13091316
# new levels are simple
13101317
lev = lev.take(uniques)
13111318

1312-
# new labels, we remove the unsued
1313-
# by decrementing the labels for that value
1314-
# prob a better way
1315-
for u in unused:
1316-
1317-
lab = np.where(lab > u, lab - 1, lab)
1318-
13191319
new_levels.append(lev)
13201320
new_labels.append(lab)
1321+
1322+
if inplace:
1323+
idx = self
1324+
else:
1325+
idx = self._shallow_copy()
13211326

1322-
# nothing changed
1323-
if not changed.any():
1324-
return self
1327+
if changed:
1328+
idx._reset_identity()
1329+
idx._set_levels(new_levels, validate=False)
1330+
idx._set_labels(new_labels, validate=False)
13251331

1326-
return MultiIndex(new_levels, new_labels,
1327-
names=self.names, sortorder=self.sortorder,
1328-
verify_integrity=False)
1332+
if not inplace:
1333+
return idx
13291334

13301335
@property
13311336
def nlevels(self):

pandas/tests/indexes/test_multi.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -2489,7 +2489,27 @@ def test_reconstruct_remove_unused(self):
24892489
# idempotent
24902490
result2 = result.remove_unused_levels()
24912491
tm.assert_index_equal(result2, expected)
2492-
assert result2 is result
2492+
assert result2.is_(result)
2493+
2494+
def test_remove_unused_levels_large(self):
2495+
# because tests should be deterministic:
2496+
rng = np.random.RandomState(4) # chosen by fair dice roll. guaranteed to be random.
2497+
2498+
size = 1<<16
2499+
df = DataFrame(dict(first=rng.randint(0, 1<<13, size),
2500+
second=rng.randint(0, 1<<10, size),
2501+
third=rng.rand(size)))
2502+
df = df.groupby(['first', 'second']).sum()
2503+
df = df[df.third < 0.1]
2504+
2505+
result = df.index.remove_unused_levels()
2506+
assert len(result.levels[0]) < len(df.index.levels[0])
2507+
assert len(result.levels[1]) < len(df.index.levels[1])
2508+
assert result.equals(df.index)
2509+
2510+
# in place
2511+
df.index.remove_unused_levels(inplace=True)
2512+
tm.assert_index_equal(df.index, result)
24932513

24942514
def test_isin(self):
24952515
values = [('foo', 2), ('bar', 3), ('quux', 4)]

0 commit comments

Comments
 (0)