Skip to content

Commit a080953

Browse files
authored
REF: avoid accessing index._engine in set_with_engine (#41959)
1 parent 764f2df commit a080953

File tree

8 files changed

+32
-26
lines changed

8 files changed

+32
-26
lines changed

pandas/core/indexes/datetimes.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
Dtype,
3535
DtypeObj,
3636
)
37-
from pandas.errors import InvalidIndexError
3837
from pandas.util._decorators import (
3938
cache_readonly,
4039
doc,
@@ -658,8 +657,7 @@ def get_loc(self, key, method=None, tolerance=None):
658657
-------
659658
loc : int
660659
"""
661-
if not is_scalar(key):
662-
raise InvalidIndexError(key)
660+
self._check_indexing_error(key)
663661

664662
orig_key = key
665663
if is_valid_na_for_dtype(key, self.dtype):

pandas/core/indexes/interval.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -613,9 +613,7 @@ def get_loc(
613613
0
614614
"""
615615
self._check_indexing_method(method)
616-
617-
if not is_scalar(key):
618-
raise InvalidIndexError(key)
616+
self._check_indexing_error(key)
619617

620618
if isinstance(key, Interval):
621619
if self.closed != key.closed:

pandas/core/indexes/period.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,12 @@
2727
Dtype,
2828
DtypeObj,
2929
)
30-
from pandas.errors import InvalidIndexError
3130
from pandas.util._decorators import doc
3231

3332
from pandas.core.dtypes.common import (
3433
is_datetime64_any_dtype,
3534
is_float,
3635
is_integer,
37-
is_scalar,
3836
pandas_dtype,
3937
)
4038
from pandas.core.dtypes.dtypes import PeriodDtype
@@ -411,9 +409,7 @@ def get_loc(self, key, method=None, tolerance=None):
411409
"""
412410
orig_key = key
413411

414-
if not is_scalar(key):
415-
raise InvalidIndexError(key)
416-
412+
self._check_indexing_error(key)
417413
if isinstance(key, str):
418414

419415
try:

pandas/core/indexes/range.py

+1
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ def get_loc(self, key, method=None, tolerance=None):
385385
return self._range.index(new_key)
386386
except ValueError as err:
387387
raise KeyError(key) from err
388+
self._check_indexing_error(key)
388389
raise KeyError(key)
389390
return super().get_loc(key, method=method, tolerance=tolerance)
390391

pandas/core/indexes/timedeltas.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
DtypeObj,
1414
Optional,
1515
)
16-
from pandas.errors import InvalidIndexError
1716

1817
from pandas.core.dtypes.common import (
1918
TD64NS_DTYPE,
@@ -170,8 +169,7 @@ def get_loc(self, key, method=None, tolerance=None):
170169
-------
171170
loc : int, slice, or ndarray[int]
172171
"""
173-
if not is_scalar(key):
174-
raise InvalidIndexError(key)
172+
self._check_indexing_error(key)
175173

176174
try:
177175
key = self._data._validate_scalar(key, unbox=False)

pandas/core/series.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1072,7 +1072,7 @@ def __setitem__(self, key, value) -> None:
10721072
# GH#12862 adding a new key to the Series
10731073
self.loc[key] = value
10741074

1075-
except TypeError as err:
1075+
except (InvalidIndexError, TypeError) as err:
10761076
if isinstance(key, tuple) and not isinstance(self.index, MultiIndex):
10771077
raise KeyError(
10781078
"key of type tuple not found and not a MultiIndex"
@@ -1094,8 +1094,7 @@ def __setitem__(self, key, value) -> None:
10941094
self._maybe_update_cacher()
10951095

10961096
def _set_with_engine(self, key, value) -> None:
1097-
# fails with AttributeError for IntervalIndex
1098-
loc = self.index._engine.get_loc(key)
1097+
loc = self.index.get_loc(key)
10991098
# error: Argument 1 to "validate_numeric_casting" has incompatible type
11001099
# "Union[dtype, ExtensionDtype]"; expected "dtype"
11011100
validate_numeric_casting(self.dtype, value) # type: ignore[arg-type]
@@ -1112,6 +1111,9 @@ def _set_with(self, key, value):
11121111

11131112
if is_scalar(key):
11141113
key = [key]
1114+
elif is_iterator(key):
1115+
# Without this, the call to infer_dtype will consume the generator
1116+
key = list(key)
11151117

11161118
if isinstance(key, Index):
11171119
key_type = key.inferred_type

pandas/tests/indexes/test_indexing.py

+22
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
IntervalIndex,
2727
MultiIndex,
2828
PeriodIndex,
29+
RangeIndex,
2930
Series,
3031
TimedeltaIndex,
3132
UInt64Index,
@@ -181,6 +182,27 @@ def test_get_value(self, index):
181182
tm.assert_almost_equal(result, values[67])
182183

183184

185+
class TestGetLoc:
186+
def test_get_loc_non_hashable(self, index):
187+
# MultiIndex and Index raise TypeError, others InvalidIndexError
188+
189+
with pytest.raises((TypeError, InvalidIndexError), match="slice"):
190+
index.get_loc(slice(0, 1))
191+
192+
def test_get_loc_generator(self, index):
193+
194+
exc = KeyError
195+
if isinstance(
196+
index,
197+
(DatetimeIndex, TimedeltaIndex, PeriodIndex, RangeIndex, IntervalIndex),
198+
):
199+
# TODO: make these more consistent?
200+
exc = InvalidIndexError
201+
with pytest.raises(exc, match="generator object"):
202+
# MultiIndex specifically checks for generator; others for scalar
203+
index.get_loc(x for x in range(5))
204+
205+
184206
class TestGetIndexer:
185207
def test_get_indexer_base(self, index):
186208

pandas/tests/indexing/test_indexing.py

-9
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,6 @@ def test_setitem_ndarray_3d(self, index, frame_or_series, indexer_sli):
113113
if indexer_sli is tm.iloc:
114114
err = ValueError
115115
msg = f"Cannot set values with ndim > {obj.ndim}"
116-
elif (
117-
isinstance(index, pd.IntervalIndex)
118-
and indexer_sli is tm.setitem
119-
and obj.ndim == 1
120-
):
121-
err = AttributeError
122-
msg = (
123-
"'pandas._libs.interval.IntervalTree' object has no attribute 'get_loc'"
124-
)
125116
else:
126117
err = ValueError
127118
msg = "|".join(

0 commit comments

Comments
 (0)