Skip to content

Commit 21dd222

Browse files
PERF: Do not init cache in RangeIndex.take
Improve performance when passing an array to RangeIndex.take, DataFrame.loc, or DataFrame.iloc and the DataFrame is using a RangeIndex
1 parent 3863a48 commit 21dd222

File tree

5 files changed

+100
-1
lines changed

5 files changed

+100
-1
lines changed

doc/source/whatsnew/v2.0.2.rst

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Bug fixes
4646

4747
Other
4848
~~~~~
49+
- Improved performance when passing an array to :meth:`RangeIndex.take`, :meth:`DataFrame.loc`, or :meth:`DataFrame.iloc` and the DataFrame is using a RangeIndex (:issue:`53387`)
4950
- Raised a better error message when calling :func:`Series.dt.to_pydatetime` with :class:`ArrowDtype` with ``pyarrow.date32`` or ``pyarrow.date64`` type (:issue:`52812`)
5051

5152
.. ---------------------------------------------------------------------------

pandas/core/indexes/base.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1154,7 +1154,6 @@ def take(
11541154
taken = values.take(
11551155
indices, allow_fill=allow_fill, fill_value=self._na_value
11561156
)
1157-
# _constructor so RangeIndex-> Index with an int64 dtype
11581157
return self._constructor._simple_new(taken, name=self.name)
11591158

11601159
@final
@@ -5504,6 +5503,10 @@ def equals(self, other: Any) -> bool:
55045503
if not isinstance(other, Index):
55055504
return False
55065505

5506+
if len(self) != len(other):
5507+
# quickly return if the lengths are different
5508+
return False
5509+
55075510
if is_object_dtype(self.dtype) and not is_object_dtype(other.dtype):
55085511
# if other is not object, use other's logic for coercion
55095512
return other.equals(self)

pandas/core/indexes/range.py

+41
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
if TYPE_CHECKING:
5151
from pandas._typing import (
52+
Axis,
5253
Dtype,
5354
NaPosition,
5455
Self,
@@ -1037,3 +1038,43 @@ def _arith_method(self, other, op):
10371038
except (ValueError, TypeError, ZeroDivisionError):
10381039
# test_arithmetic_explicit_conversions
10391040
return super()._arith_method(other, op)
1041+
1042+
def take(
    self,
    indices,
    axis: Axis = 0,
    allow_fill: bool = True,
    fill_value=None,
    **kwargs,
):
    """
    Return the values found at the given positions as a new index.

    The values are derived arithmetically from ``start`` and ``step``
    instead of being looked up in an array holding the whole range.

    Parameters
    ----------
    indices : array-like
        Positions to take. Negative positions are wrapped with
        ``len(self)``, i.e. they count from the end.
    axis : Axis, default 0
        Not used by this implementation.
    allow_fill : bool, default True
        Handed to ``_maybe_disallow_fill``; the value it returns must
        be False.
    fill_value : scalar, default None
        Handed to ``_maybe_disallow_fill`` together with ``allow_fill``.
    **kwargs
        When non-empty, validated with ``nv.validate_take``.

    Returns
    -------
    Index
        Built with ``self._constructor`` and carrying ``self.name``.

    Raises
    ------
    TypeError
        If ``indices`` is a scalar.
    IndexError
        If any position is ``>= len(self)`` or ``< -len(self)``.
    """
    if kwargs:
        nv.validate_take((), kwargs)
    if is_scalar(indices):
        raise TypeError("Expected indices to be array-like")

    indices = ensure_platform_int(indices)
    allow_fill = self._maybe_disallow_fill(allow_fill, fill_value, indices)
    assert allow_fill is False, "allow_fill isn't supported by RangeIndex"

    if len(indices) == 0:
        # nothing to bounds-check or translate
        values = np.array([], dtype=self.dtype)
    else:
        size = len(self)

        # the upper bound is checked first, then the lower bound
        highest = indices.max()
        if highest >= size:
            raise IndexError(
                f"index {highest} is out of bounds for axis 0 with size {size}"
            )
        lowest = indices.min()
        if lowest < -size:
            raise IndexError(
                f"index {lowest} is out of bounds for axis 0 with size {size}"
            )

        # position -> value: wrap negatives, scale by step, shift by start
        values = indices.astype(self.dtype, casting="safe")
        if lowest < 0:
            values %= size
        if self.step != 1:
            values *= self.step
        if self.start != 0:
            values += self.start

    # _constructor so RangeIndex-> Index with an int64 dtype
    return self._constructor._simple_new(values, name=self.name)

pandas/tests/indexes/ranges/test_indexing.py

+36
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,46 @@ def test_take_fill_value(self):
7676
with pytest.raises(ValueError, match=msg):
7777
idx.take(np.array([1, 0, -5]), fill_value=True)
7878

79+
def test_take_raises_index_error(self):
    """Positions below ``-len(index)`` raise; ``-len(index)`` is accepted."""
    rng = RangeIndex(1, 4, name="xxx")

    # -5 and -4 both lie below -len(rng) == -3
    for bad_pos in (-5, -4):
        pattern = f"index {bad_pos} is out of bounds for (axis 0 with )?size 3"
        with pytest.raises(IndexError, match=pattern):
            rng.take(np.array([1, bad_pos]))

    # no errors
    taken = rng.take(np.array([1, -3]))
    want = Index([2, 1], dtype=np.int64, name="xxx")
    tm.assert_index_equal(taken, want)
94+
95+
def test_take_accepts_empty_array(self):
96+
idx = RangeIndex(1, 4, name="foo")
97+
result = idx.take(np.array([]))
98+
expected = Index([], dtype=np.int64, name="foo")
99+
tm.assert_index_equal(result, expected)
100+
101+
# empty index
102+
idx = RangeIndex(0, name="foo")
103+
result = idx.take(np.array([]))
104+
expected = Index([], dtype=np.int64, name="foo")
105+
tm.assert_index_equal(result, expected)
106+
107+
def test_take_accepts_non_int64_array(self):
108+
idx = RangeIndex(1, 4, name="foo")
109+
result = idx.take(np.array([2, 1], dtype=np.uint32))
110+
expected = Index([3, 2], dtype=np.int64, name="foo")
111+
tm.assert_index_equal(result, expected)
112+
113+
def test_take_when_index_has_step(self):
114+
idx = RangeIndex(1, 11, 3, name="foo") # [1, 4, 7, 10]
115+
result = idx.take(np.array([1, 0, -1, -4]))
116+
expected = Index([4, 1, 10, 1], dtype=np.int64, name="foo")
117+
tm.assert_index_equal(result, expected)
118+
83119

84120
class TestWhere:
85121
def test_where_putmask_range_cast(self):

pandas/tests/indexes/ranges/test_range.py

+18
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,24 @@ def test_cache(self):
295295
expected = np.arange(0, 100, 10, dtype="int64")
296296
tm.assert_numpy_array_equal(idx._cache["_data"], expected)
297297

298+
def test_cache_after_calling_loc_with_array(self):
299+
# GH 53387
300+
# the cache will contain a _constructor key, so it should be tested separately
301+
idx = RangeIndex(0, 100, 10)
302+
df = pd.DataFrame({"a": range(10)}, index=idx)
303+
304+
assert "_data" not in idx._cache
305+
306+
# take is internally called by loc, but it's also tested explicitly
307+
idx.take([3, 0, 1])
308+
assert "_data" not in idx._cache
309+
310+
df.loc[[50]]
311+
assert "_data" not in idx._cache
312+
313+
df.iloc[[5, 6, 7, 8, 9]]
314+
assert "_data" not in idx._cache
315+
298316
def test_is_monotonic(self):
299317
index = RangeIndex(0, 20, 2)
300318
assert index.is_monotonic_increasing is True

0 commit comments

Comments
 (0)