Skip to content

Commit 9065c18

Browse files
PERF: Do not init cache in RangeIndex.take
Improve performance when passing an array to RangeIndex.take, DataFrame.loc, or DataFrame.iloc and the DataFrame is using a RangeIndex
1 parent 8cf4ab4 commit 9065c18

File tree

5 files changed

+100
-1
lines changed

5 files changed

+100
-1
lines changed

doc/source/whatsnew/v2.0.2.rst

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Bug fixes
4646

4747
Other
4848
~~~~~
49+
- Improved performance when passing an array to :meth:`RangeIndex.take`, :meth:`DataFrame.loc`, or :meth:`DataFrame.iloc` and the DataFrame is using a RangeIndex (:issue:`53387`)
4950
- Raised a better error message when calling :func:`Series.dt.to_pydatetime` with :class:`ArrowDtype` with ``pyarrow.date32`` or ``pyarrow.date64`` type (:issue:`52812`)
5051

5152
.. ---------------------------------------------------------------------------

pandas/core/indexes/base.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1162,7 +1162,6 @@ def take(
11621162
taken = values.take(
11631163
indices, allow_fill=allow_fill, fill_value=self._na_value
11641164
)
1165-
# _constructor so RangeIndex-> Index with an int64 dtype
11661165
return self._constructor._simple_new(taken, name=self.name)
11671166

11681167
@final
@@ -5537,6 +5536,10 @@ def equals(self, other: Any) -> bool:
55375536
if not isinstance(other, Index):
55385537
return False
55395538

5539+
if len(self) != len(other):
5540+
# quickly return if the lengths are different
5541+
return False
5542+
55405543
if is_object_dtype(self.dtype) and not is_object_dtype(other.dtype):
55415544
# if other is not object, use other's logic for coercion
55425545
return other.equals(self)

pandas/core/indexes/range.py

+41
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
if TYPE_CHECKING:
5151
from pandas._typing import (
52+
Axis,
5253
Dtype,
5354
NaPosition,
5455
Self,
@@ -1037,3 +1038,43 @@ def _arith_method(self, other, op):
10371038
except (ValueError, TypeError, ZeroDivisionError):
10381039
# test_arithmetic_explicit_conversions
10391040
return super()._arith_method(other, op)
1041+
1042+
def take(
1043+
self,
1044+
indices,
1045+
axis: Axis = 0,
1046+
allow_fill: bool = True,
1047+
fill_value=None,
1048+
**kwargs,
1049+
):
1050+
if kwargs:
1051+
nv.validate_take((), kwargs)
1052+
if is_scalar(indices):
1053+
raise TypeError("Expected indices to be array-like")
1054+
indices = ensure_platform_int(indices)
1055+
allow_fill = self._maybe_disallow_fill(allow_fill, fill_value, indices)
1056+
assert allow_fill is False, "allow_fill isn't supported by RangeIndex"
1057+
1058+
if len(indices) == 0:
1059+
taken = np.array([], dtype=self.dtype)
1060+
else:
1061+
ind_max = indices.max()
1062+
if ind_max >= len(self):
1063+
raise IndexError(
1064+
f"index {ind_max} is out of bounds for axis 0 with size {len(self)}"
1065+
)
1066+
ind_min = indices.min()
1067+
if ind_min < -len(self):
1068+
raise IndexError(
1069+
f"index {ind_min} is out of bounds for axis 0 with size {len(self)}"
1070+
)
1071+
taken = indices.astype(self.dtype, casting="safe")
1072+
if ind_min < 0:
1073+
taken %= len(self)
1074+
if self.step != 1:
1075+
taken *= self.step
1076+
if self.start != 0:
1077+
taken += self.start
1078+
1079+
# _constructor so RangeIndex-> Index with an int64 dtype
1080+
return self._constructor._simple_new(taken, name=self.name)

pandas/tests/indexes/ranges/test_indexing.py

+36
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,46 @@ def test_take_fill_value(self):
7676
with pytest.raises(ValueError, match=msg):
7777
idx.take(np.array([1, 0, -5]), fill_value=True)
7878

79+
def test_take_raises_index_error(self):
80+
idx = RangeIndex(1, 4, name="xxx")
81+
7982
msg = "index -5 is out of bounds for (axis 0 with )?size 3"
8083
with pytest.raises(IndexError, match=msg):
8184
idx.take(np.array([1, -5]))
8285

86+
msg = "index -4 is out of bounds for (axis 0 with )?size 3"
87+
with pytest.raises(IndexError, match=msg):
88+
idx.take(np.array([1, -4]))
89+
90+
# no errors
91+
result = idx.take(np.array([1, -3]))
92+
expected = Index([2, 1], dtype=np.int64, name="xxx")
93+
tm.assert_index_equal(result, expected)
94+
95+
def test_take_accepts_empty_array(self):
96+
idx = RangeIndex(1, 4, name="foo")
97+
result = idx.take(np.array([]))
98+
expected = Index([], dtype=np.int64, name="foo")
99+
tm.assert_index_equal(result, expected)
100+
101+
# empty index
102+
idx = RangeIndex(0, name="foo")
103+
result = idx.take(np.array([]))
104+
expected = Index([], dtype=np.int64, name="foo")
105+
tm.assert_index_equal(result, expected)
106+
107+
def test_take_accepts_non_int64_array(self):
108+
idx = RangeIndex(1, 4, name="foo")
109+
result = idx.take(np.array([2, 1], dtype=np.uint32))
110+
expected = Index([3, 2], dtype=np.int64, name="foo")
111+
tm.assert_index_equal(result, expected)
112+
113+
def test_take_when_index_has_step(self):
114+
idx = RangeIndex(1, 11, 3, name="foo") # [1, 4, 7, 10]
115+
result = idx.take(np.array([1, 0, -1, -4]))
116+
expected = Index([4, 1, 10, 1], dtype=np.int64, name="foo")
117+
tm.assert_index_equal(result, expected)
118+
83119

84120
class TestWhere:
85121
def test_where_putmask_range_cast(self):

pandas/tests/indexes/ranges/test_range.py

+18
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,24 @@ def test_cache(self):
295295
expected = np.arange(0, 100, 10, dtype="int64")
296296
tm.assert_numpy_array_equal(idx._cache["_data"], expected)
297297

298+
def test_cache_after_calling_loc_with_array(self):
299+
# GH 53387
300+
# the cache will contain a _constructor key, so it should be tested separately
301+
idx = RangeIndex(0, 100, 10)
302+
df = pd.DataFrame({"a": range(10)}, index=idx)
303+
304+
assert "_data" not in idx._cache
305+
306+
# take is internally called by loc, but it's also tested explicitly
307+
idx.take([3, 0, 1])
308+
assert "_data" not in idx._cache
309+
310+
df.loc[[50]]
311+
assert "_data" not in idx._cache
312+
313+
df.iloc[[5, 6, 7, 8, 9]]
314+
assert "_data" not in idx._cache
315+
298316
def test_is_monotonic(self):
299317
index = RangeIndex(0, 20, 2)
300318
assert index.is_monotonic_increasing is True

0 commit comments

Comments
 (0)