Skip to content

Commit 9cc9053

Browse files
authored
BUG: IntervalArray.insert cast on failure (#40359)
1 parent 60e9189 commit 9cc9053

File tree

4 files changed

+46
-16
lines changed

4 files changed

+46
-16
lines changed

pandas/core/arrays/interval.py

+22
Original file line numberDiff line numberDiff line change
@@ -1487,6 +1487,28 @@ def putmask(self, mask: np.ndarray, value) -> None:
14871487
self._left.putmask(mask, value_left)
14881488
self._right.putmask(mask, value_right)
14891489

1490+
def insert(self: IntervalArrayT, loc: int, item: Interval) -> IntervalArrayT:
1491+
"""
1492+
Return a new IntervalArray inserting new item at location. Follows
1493+
Python list.append semantics for negative values. Only Interval
1494+
objects and NA can be inserted into an IntervalIndex
1495+
1496+
Parameters
1497+
----------
1498+
loc : int
1499+
item : Interval
1500+
1501+
Returns
1502+
-------
1503+
IntervalArray
1504+
"""
1505+
left_insert, right_insert = self._validate_scalar(item)
1506+
1507+
new_left = self.left.insert(loc, left_insert)
1508+
new_right = self.right.insert(loc, right_insert)
1509+
1510+
return self._shallow_copy(new_left, new_right)
1511+
14901512
def delete(self: IntervalArrayT, loc) -> IntervalArrayT:
14911513
if isinstance(self._left, np.ndarray):
14921514
new_left = np.delete(self._left, loc)

pandas/core/indexes/interval.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -873,11 +873,14 @@ def insert(self, loc, item):
873873
-------
874874
IntervalIndex
875875
"""
876-
left_insert, right_insert = self._data._validate_scalar(item)
876+
try:
877+
result = self._data.insert(loc, item)
878+
except (ValueError, TypeError):
879+
# e.g trying to insert a string
880+
dtype, _ = infer_dtype_from_scalar(item, pandas_dtype=True)
881+
dtype = find_common_type([self.dtype, dtype])
882+
return self.astype(dtype).insert(loc, item)
877883

878-
new_left = self.left.insert(loc, left_insert)
879-
new_right = self.right.insert(loc, right_insert)
880-
result = self._data._shallow_copy(new_left, new_right)
881884
return type(self)._simple_new(result, name=self.name)
882885

883886
# --------------------------------------------------------------------

pandas/tests/indexes/interval/test_interval.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -194,17 +194,24 @@ def test_insert(self, data):
194194
tm.assert_index_equal(result, expected)
195195

196196
# invalid type
197+
res = data.insert(1, "foo")
198+
expected = data.astype(object).insert(1, "foo")
199+
tm.assert_index_equal(res, expected)
200+
197201
msg = "can only insert Interval objects and NA into an IntervalArray"
198202
with pytest.raises(TypeError, match=msg):
199-
data.insert(1, "foo")
203+
data._data.insert(1, "foo")
200204

201205
# invalid closed
202206
msg = "'value.closed' is 'left', expected 'right'."
203207
for closed in {"left", "right", "both", "neither"} - {item.closed}:
204208
msg = f"'value.closed' is '{closed}', expected '{item.closed}'."
209+
bad_item = Interval(item.left, item.right, closed=closed)
210+
res = data.insert(1, bad_item)
211+
expected = data.astype(object).insert(1, bad_item)
212+
tm.assert_index_equal(res, expected)
205213
with pytest.raises(ValueError, match=msg):
206-
bad_item = Interval(item.left, item.right, closed=closed)
207-
data.insert(1, bad_item)
214+
data._data.insert(1, bad_item)
208215

209216
# GH 18295 (test missing)
210217
na_idx = IntervalIndex([np.nan], closed=data.closed)
@@ -214,13 +221,15 @@ def test_insert(self, data):
214221
tm.assert_index_equal(result, expected)
215222

216223
if data.left.dtype.kind not in ["m", "M"]:
217-
# trying to insert pd.NaT into a numeric-dtyped Index should cast/raise
224+
# trying to insert pd.NaT into a numeric-dtyped Index should cast
225+
expected = data.astype(object).insert(1, pd.NaT)
226+
218227
msg = "can only insert Interval objects and NA into an IntervalArray"
219228
with pytest.raises(TypeError, match=msg):
220-
result = data.insert(1, pd.NaT)
221-
else:
222-
result = data.insert(1, pd.NaT)
223-
tm.assert_index_equal(result, expected)
229+
data._data.insert(1, pd.NaT)
230+
231+
result = data.insert(1, pd.NaT)
232+
tm.assert_index_equal(result, expected)
224233

225234
def test_is_unique_interval(self, closed):
226235
"""

pandas/tests/indexing/test_loc.py

-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
DatetimeIndex,
2424
Index,
2525
IndexSlice,
26-
IntervalIndex,
2726
MultiIndex,
2827
Period,
2928
Series,
@@ -1680,9 +1679,6 @@ def test_loc_setitem_with_expansion_nonunique_index(self, index, request):
16801679
# GH#40096
16811680
if not len(index):
16821681
return
1683-
if isinstance(index, IntervalIndex):
1684-
mark = pytest.mark.xfail(reason="IntervalIndex raises")
1685-
request.node.add_marker(mark)
16861682

16871683
index = index.repeat(2) # ensure non-unique
16881684
N = len(index)

0 commit comments

Comments
 (0)