From d534b73ead028d40b5d7e517c48235439582e544 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 10 Mar 2021 13:37:49 -0800 Subject: [PATCH] BUG: IntervalArray.insert cast on failure --- pandas/core/arrays/interval.py | 22 ++++++++++++++++ pandas/core/indexes/interval.py | 11 +++++--- .../tests/indexes/interval/test_interval.py | 25 +++++++++++++------ pandas/tests/indexing/test_loc.py | 4 --- 4 files changed, 46 insertions(+), 16 deletions(-) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index f192a34514390..3ee550620e73a 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -1483,6 +1483,28 @@ def putmask(self, mask: np.ndarray, value) -> None: self._left.putmask(mask, value_left) self._right.putmask(mask, value_right) + def insert(self: IntervalArrayT, loc: int, item: Interval) -> IntervalArrayT: + """ + Return a new IntervalArray inserting new item at location. Follows + Python list.append semantics for negative values. Only Interval + objects and NA can be inserted into an IntervalIndex + + Parameters + ---------- + loc : int + item : Interval + + Returns + ------- + IntervalArray + """ + left_insert, right_insert = self._validate_scalar(item) + + new_left = self.left.insert(loc, left_insert) + new_right = self.right.insert(loc, right_insert) + + return self._shallow_copy(new_left, new_right) + def delete(self: IntervalArrayT, loc) -> IntervalArrayT: if isinstance(self._left, np.ndarray): new_left = np.delete(self._left, loc) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index ad512b8393166..10fdc642ba7ce 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -873,11 +873,14 @@ def insert(self, loc, item): ------- IntervalIndex """ - left_insert, right_insert = self._data._validate_scalar(item) + try: + result = self._data.insert(loc, item) + except (ValueError, TypeError): + # e.g trying to insert a string + dtype, _ = infer_dtype_from_scalar(item, pandas_dtype=True) + dtype = find_common_type([self.dtype, dtype]) + return self.astype(dtype).insert(loc, item) - new_left = self.left.insert(loc, left_insert) - new_right = self.right.insert(loc, right_insert) - result = self._data._shallow_copy(new_left, new_right) return type(self)._simple_new(result, name=self.name) # -------------------------------------------------------------------- diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index 02ef3cb0e2afb..cd61fcaa835a4 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -194,17 +194,24 @@ def test_insert(self, data): tm.assert_index_equal(result, expected) # invalid type + res = data.insert(1, "foo") + expected = data.astype(object).insert(1, "foo") + tm.assert_index_equal(res, expected) + msg = "can only insert Interval objects and NA into an IntervalArray" with pytest.raises(TypeError, match=msg): - data.insert(1, "foo") + data._data.insert(1, "foo") # invalid closed msg = "'value.closed' is 'left', expected 'right'." for closed in {"left", "right", "both", "neither"} - {item.closed}: msg = f"'value.closed' is '{closed}', expected '{item.closed}'." + bad_item = Interval(item.left, item.right, closed=closed) + res = data.insert(1, bad_item) + expected = data.astype(object).insert(1, bad_item) + tm.assert_index_equal(res, expected) with pytest.raises(ValueError, match=msg): - bad_item = Interval(item.left, item.right, closed=closed) - data.insert(1, bad_item) + data._data.insert(1, bad_item) # GH 18295 (test missing) na_idx = IntervalIndex([np.nan], closed=data.closed) @@ -214,13 +221,15 @@ def test_insert(self, data): tm.assert_index_equal(result, expected) if data.left.dtype.kind not in ["m", "M"]: - # trying to insert pd.NaT into a numeric-dtyped Index should cast/raise + # trying to insert pd.NaT into a numeric-dtyped Index should cast + expected = data.astype(object).insert(1, pd.NaT) + msg = "can only insert Interval objects and NA into an IntervalArray" with pytest.raises(TypeError, match=msg): - result = data.insert(1, pd.NaT) - else: - result = data.insert(1, pd.NaT) - tm.assert_index_equal(result, expected) + data._data.insert(1, pd.NaT) + + result = data.insert(1, pd.NaT) + tm.assert_index_equal(result, expected) def test_is_unique_interval(self, closed): """ diff --git a/pandas/tests/indexing/test_loc.py b/pandas/tests/indexing/test_loc.py index 9dbce283d2a8f..c7e9b3eb5b852 100644 --- a/pandas/tests/indexing/test_loc.py +++ b/pandas/tests/indexing/test_loc.py @@ -23,7 +23,6 @@ DatetimeIndex, Index, IndexSlice, - IntervalIndex, MultiIndex, Period, Series, @@ -1680,9 +1679,6 @@ def test_loc_setitem_with_expansion_nonunique_index(self, index, request): # GH#40096 if not len(index): return - if isinstance(index, IntervalIndex): - mark = pytest.mark.xfail(reason="IntervalIndex raises") - request.node.add_marker(mark) index = index.repeat(2) # ensure non-unique N = len(index)