Skip to content

BUG: IntervalArray.insert cast on failure #40359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,28 @@ def putmask(self, mask: np.ndarray, value) -> None:
self._left.putmask(mask, value_left)
self._right.putmask(mask, value_right)

def insert(self: IntervalArrayT, loc: int, item: Interval) -> IntervalArrayT:
"""
Return a new IntervalArray inserting new item at location. Follows
Python list.append semantics for negative values. Only Interval
objects and NA can be inserted into an IntervalIndex

Parameters
----------
loc : int
item : Interval

Returns
-------
IntervalArray
"""
left_insert, right_insert = self._validate_scalar(item)

new_left = self.left.insert(loc, left_insert)
new_right = self.right.insert(loc, right_insert)

return self._shallow_copy(new_left, new_right)

def delete(self: IntervalArrayT, loc) -> IntervalArrayT:
if isinstance(self._left, np.ndarray):
new_left = np.delete(self._left, loc)
Expand Down
11 changes: 7 additions & 4 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,11 +873,14 @@ def insert(self, loc, item):
-------
IntervalIndex
"""
left_insert, right_insert = self._data._validate_scalar(item)
try:
result = self._data.insert(loc, item)
except (ValueError, TypeError):
# e.g trying to insert a string
dtype, _ = infer_dtype_from_scalar(item, pandas_dtype=True)
dtype = find_common_type([self.dtype, dtype])
return self.astype(dtype).insert(loc, item)

new_left = self.left.insert(loc, left_insert)
new_right = self.right.insert(loc, right_insert)
result = self._data._shallow_copy(new_left, new_right)
return type(self)._simple_new(result, name=self.name)

# --------------------------------------------------------------------
Expand Down
25 changes: 17 additions & 8 deletions pandas/tests/indexes/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,24 @@ def test_insert(self, data):
tm.assert_index_equal(result, expected)

# invalid type
res = data.insert(1, "foo")
expected = data.astype(object).insert(1, "foo")
tm.assert_index_equal(res, expected)

msg = "can only insert Interval objects and NA into an IntervalArray"
with pytest.raises(TypeError, match=msg):
data.insert(1, "foo")
data._data.insert(1, "foo")

# invalid closed
msg = "'value.closed' is 'left', expected 'right'."
for closed in {"left", "right", "both", "neither"} - {item.closed}:
msg = f"'value.closed' is '{closed}', expected '{item.closed}'."
bad_item = Interval(item.left, item.right, closed=closed)
res = data.insert(1, bad_item)
expected = data.astype(object).insert(1, bad_item)
tm.assert_index_equal(res, expected)
with pytest.raises(ValueError, match=msg):
bad_item = Interval(item.left, item.right, closed=closed)
data.insert(1, bad_item)
data._data.insert(1, bad_item)

# GH 18295 (test missing)
na_idx = IntervalIndex([np.nan], closed=data.closed)
Expand All @@ -214,13 +221,15 @@ def test_insert(self, data):
tm.assert_index_equal(result, expected)

if data.left.dtype.kind not in ["m", "M"]:
# trying to insert pd.NaT into a numeric-dtyped Index should cast/raise
# trying to insert pd.NaT into a numeric-dtyped Index should cast
expected = data.astype(object).insert(1, pd.NaT)

msg = "can only insert Interval objects and NA into an IntervalArray"
with pytest.raises(TypeError, match=msg):
result = data.insert(1, pd.NaT)
else:
result = data.insert(1, pd.NaT)
tm.assert_index_equal(result, expected)
data._data.insert(1, pd.NaT)

result = data.insert(1, pd.NaT)
tm.assert_index_equal(result, expected)

def test_is_unique_interval(self, closed):
"""
Expand Down
4 changes: 0 additions & 4 deletions pandas/tests/indexing/test_loc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
DatetimeIndex,
Index,
IndexSlice,
IntervalIndex,
MultiIndex,
Period,
Series,
Expand Down Expand Up @@ -1680,9 +1679,6 @@ def test_loc_setitem_with_expansion_nonunique_index(self, index, request):
# GH#40096
if not len(index):
return
if isinstance(index, IntervalIndex):
mark = pytest.mark.xfail(reason="IntervalIndex raises")
request.node.add_marker(mark)

index = index.repeat(2) # ensure non-unique
N = len(index)
Expand Down