Skip to content

Commit 090d6a1

Browse files
authored
ENH: preserve RangeIndex in factorize (#38034)
* ENH: preserve RangeIndex in factorize * dedoc * ensure np.intp * 32bit compat * if->elif
1 parent 78d1498 commit 090d6a1

File tree

5 files changed

+66
-2
lines changed

5 files changed

+66
-2
lines changed

pandas/core/algorithms.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
ABCExtensionArray,
5252
ABCIndexClass,
5353
ABCMultiIndex,
54+
ABCRangeIndex,
5455
ABCSeries,
5556
)
5657
from pandas.core.dtypes.missing import isna, na_value_for_dtype
@@ -682,7 +683,9 @@ def factorize(
682683
na_sentinel = -1
683684
dropna = False
684685

685-
if is_extension_array_dtype(values.dtype):
686+
if isinstance(values, ABCRangeIndex):
687+
return values.factorize(sort=sort)
688+
elif is_extension_array_dtype(values.dtype):
686689
values = extract_array(values)
687690
codes, uniques = values.factorize(na_sentinel=na_sentinel)
688691
dtype = original.dtype

pandas/core/indexes/range.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from datetime import timedelta
22
import operator
33
from sys import getsizeof
4-
from typing import Any, List
4+
from typing import Any, List, Optional, Tuple
55
import warnings
66

77
import numpy as np
@@ -461,6 +461,16 @@ def argsort(self, *args, **kwargs) -> np.ndarray:
461461
else:
462462
return np.arange(len(self) - 1, -1, -1)
463463

464+
def factorize(
465+
self, sort: bool = False, na_sentinel: Optional[int] = -1
466+
) -> Tuple[np.ndarray, "RangeIndex"]:
467+
codes = np.arange(len(self), dtype=np.intp)
468+
uniques = self
469+
if sort and self.step < 0:
470+
codes = codes[::-1]
471+
uniques = uniques[::-1]
472+
return codes, uniques
473+
464474
def equals(self, other: object) -> bool:
465475
"""
466476
Determines if two Index objects contain the same elements.

pandas/tests/arrays/categorical/test_constructors.py

+10
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,16 @@ def test_constructor_with_generator(self):
290290
cat = Categorical([0, 1, 2], categories=range(3))
291291
tm.assert_categorical_equal(cat, exp)
292292

293+
def test_constructor_with_rangeindex(self):
294+
# RangeIndex is preserved in Categories
295+
rng = Index(range(3))
296+
297+
cat = Categorical(rng)
298+
tm.assert_index_equal(cat.categories, rng, exact=True)
299+
300+
cat = Categorical([1, 2, 0], categories=rng)
301+
tm.assert_index_equal(cat.categories, rng, exact=True)
302+
293303
@pytest.mark.parametrize(
294304
"dtl",
295305
[

pandas/tests/indexes/multi/test_constructors.py

+8
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,14 @@ def test_from_product_datetimeindex():
477477
tm.assert_numpy_array_equal(mi.values, etalon)
478478

479479

480+
def test_from_product_rangeindex():
481+
# RangeIndex is preserved by factorize, so preserved in levels
482+
rng = Index(range(5))
483+
other = ["a", "b"]
484+
mi = MultiIndex.from_product([rng, other])
485+
tm.assert_index_equal(mi._levels[0], rng, exact=True)
486+
487+
480488
@pytest.mark.parametrize("ordered", [False, True])
481489
@pytest.mark.parametrize("f", [lambda x: x, lambda x: Series(x), lambda x: x.values])
482490
def test_from_product_index_series_categorical(ordered, f):

pandas/tests/test_algos.py

+33
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,39 @@ def test_datetime64_factorize(self, writable):
307307
tm.assert_numpy_array_equal(codes, expected_codes)
308308
tm.assert_numpy_array_equal(uniques, expected_uniques)
309309

310+
@pytest.mark.parametrize("sort", [True, False])
311+
def test_factorize_rangeindex(self, sort):
312+
# increasing -> sort doesn't matter
313+
ri = pd.RangeIndex.from_range(range(10))
314+
expected = np.arange(10, dtype=np.intp), ri
315+
316+
result = algos.factorize(ri, sort=sort)
317+
tm.assert_numpy_array_equal(result[0], expected[0])
318+
tm.assert_index_equal(result[1], expected[1], exact=True)
319+
320+
result = ri.factorize(sort=sort)
321+
tm.assert_numpy_array_equal(result[0], expected[0])
322+
tm.assert_index_equal(result[1], expected[1], exact=True)
323+
324+
@pytest.mark.parametrize("sort", [True, False])
325+
def test_factorize_rangeindex_decreasing(self, sort):
326+
# decreasing -> sort matters
327+
ri = pd.RangeIndex.from_range(range(10))
328+
expected = np.arange(10, dtype=np.intp), ri
329+
330+
ri2 = ri[::-1]
331+
expected = expected[0], ri2
332+
if sort:
333+
expected = expected[0][::-1], expected[1][::-1]
334+
335+
result = algos.factorize(ri2, sort=sort)
336+
tm.assert_numpy_array_equal(result[0], expected[0])
337+
tm.assert_index_equal(result[1], expected[1], exact=True)
338+
339+
result = ri2.factorize(sort=sort)
340+
tm.assert_numpy_array_equal(result[0], expected[0])
341+
tm.assert_index_equal(result[1], expected[1], exact=True)
342+
310343
def test_deprecate_order(self):
311344
# gh 19727 - check warning is raised for deprecated keyword, order.
312345
# Test not valid once order keyword is removed.

0 commit comments

Comments
 (0)