diff --git a/doc/source/reference/extensions.rst b/doc/source/reference/extensions.rst index e2e8c94ef8fc6..ce8d8d5c2ca10 100644 --- a/doc/source/reference/extensions.rst +++ b/doc/source/reference/extensions.rst @@ -48,6 +48,7 @@ objects. api.extensions.ExtensionArray.equals api.extensions.ExtensionArray.factorize api.extensions.ExtensionArray.fillna + api.extensions.ExtensionArray.insert api.extensions.ExtensionArray.isin api.extensions.ExtensionArray.isna api.extensions.ExtensionArray.ravel diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 3769c686da029..cf9820c3aa8f8 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -31,6 +31,7 @@ from pandas.util._validators import ( validate_bool_kwarg, validate_fillna_kwargs, + validate_insert_loc, ) from pandas.core.dtypes.common import is_dtype_equal @@ -359,6 +360,8 @@ def insert( ------- type(self) """ + loc = validate_insert_loc(loc, len(self)) + code = self._validate_scalar(item) new_vals = np.concatenate( diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index bf54f7166e14d..9b25a1b5abccd 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -47,6 +47,7 @@ from pandas.util._validators import ( validate_bool_kwarg, validate_fillna_kwargs, + validate_insert_loc, ) from pandas.core.dtypes.cast import maybe_cast_to_extension_array @@ -123,6 +124,7 @@ class ExtensionArray: factorize fillna equals + insert isin isna ravel @@ -1388,6 +1390,34 @@ def delete(self: ExtensionArrayT, loc: PositionalIndexer) -> ExtensionArrayT: indexer = np.delete(np.arange(len(self)), loc) return self.take(indexer) + def insert(self: ExtensionArrayT, loc: int, item) -> ExtensionArrayT: + """ + Insert an item at the given position. + + Parameters + ---------- + loc : int + item : scalar-like + + Returns + ------- + same type as self + + Notes + ----- + This method should be both type and dtype-preserving. If the item + cannot be held in an array of this type/dtype, either ValueError or + TypeError should be raised. + + The default implementation relies on _from_sequence to raise on invalid + items. + """ + loc = validate_insert_loc(loc, len(self)) + + item_arr = type(self)._from_sequence([item], dtype=self.dtype) + + return type(self)._concat_same_type([self[:loc], item_arr, self[loc:]]) + @classmethod def _empty(cls, shape: Shape, dtype: ExtensionDtype): """ diff --git a/pandas/tests/arithmetic/test_interval.py b/pandas/tests/arithmetic/test_interval.py index 12220e825aed4..88b26dcc4d707 100644 --- a/pandas/tests/arithmetic/test_interval.py +++ b/pandas/tests/arithmetic/test_interval.py @@ -28,16 +28,16 @@ (Index([0, 2, 4, 4]), Index([1, 3, 5, 8])), (Index([0.0, 1.0, 2.0, np.nan]), Index([1.0, 2.0, 3.0, np.nan])), ( - timedelta_range("0 days", periods=3).insert(4, pd.NaT), - timedelta_range("1 day", periods=3).insert(4, pd.NaT), + timedelta_range("0 days", periods=3).insert(3, pd.NaT), + timedelta_range("1 day", periods=3).insert(3, pd.NaT), ), ( - date_range("20170101", periods=3).insert(4, pd.NaT), - date_range("20170102", periods=3).insert(4, pd.NaT), + date_range("20170101", periods=3).insert(3, pd.NaT), + date_range("20170102", periods=3).insert(3, pd.NaT), ), ( - date_range("20170101", periods=3, tz="US/Eastern").insert(4, pd.NaT), - date_range("20170102", periods=3, tz="US/Eastern").insert(4, pd.NaT), + date_range("20170101", periods=3, tz="US/Eastern").insert(3, pd.NaT), + date_range("20170102", periods=3, tz="US/Eastern").insert(3, pd.NaT), ), ], ids=lambda x: str(x[0].dtype), diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index d390d4b5d8143..c96e2fb49e397 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -511,6 +511,48 @@ def test_delete(self, data): expected = data._concat_same_type([data[[0]], data[[2]], data[4:]]) self.assert_extension_array_equal(result, expected) + def test_insert(self, data): + # insert at the beginning + result = data[1:].insert(0, data[0]) + self.assert_extension_array_equal(result, data) + + result = data[1:].insert(-len(data[1:]), data[0]) + self.assert_extension_array_equal(result, data) + + # insert at the middle + result = data[:-1].insert(4, data[-1]) + + taker = np.arange(len(data)) + taker[5:] = taker[4:-1] + taker[4] = len(data) - 1 + expected = data.take(taker) + self.assert_extension_array_equal(result, expected) + + def test_insert_invalid(self, data, invalid_scalar): + item = invalid_scalar + + with pytest.raises((TypeError, ValueError)): + data.insert(0, item) + + with pytest.raises((TypeError, ValueError)): + data.insert(4, item) + + with pytest.raises((TypeError, ValueError)): + data.insert(len(data) - 1, item) + + def test_insert_invalid_loc(self, data): + ub = len(data) + + with pytest.raises(IndexError): + data.insert(ub + 1, data[0]) + + with pytest.raises(IndexError): + data.insert(-ub - 1, data[0]) + + with pytest.raises(TypeError): + # we expect TypeError here instead of IndexError to match np.insert + data.insert(1.5, data[0]) + @pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame]) def test_equals(self, data, na_value, as_series, box): data2 = type(data)._from_sequence([data[0]] * len(data), dtype=data.dtype) diff --git a/pandas/tests/extension/conftest.py b/pandas/tests/extension/conftest.py index 1942d737780da..3827ba234cfd8 100644 --- a/pandas/tests/extension/conftest.py +++ b/pandas/tests/extension/conftest.py @@ -181,3 +181,15 @@ def as_array(request): Boolean fixture to support ExtensionDtype _from_sequence method testing. """ return request.param + + +@pytest.fixture +def invalid_scalar(data): + """ + A scalar that *cannot* be held by this ExtensionArray. + + The default should work for most subclasses, but is not guaranteed. + + If the array can hold any item (i.e. object dtype), then use pytest.skip. + """ + return object.__new__(object) diff --git a/pandas/tests/extension/test_numpy.py b/pandas/tests/extension/test_numpy.py index 7be776819e399..0e3e26e7e9500 100644 --- a/pandas/tests/extension/test_numpy.py +++ b/pandas/tests/extension/test_numpy.py @@ -265,6 +265,18 @@ def test_searchsorted(self, data_for_sorting, as_series): def test_diff(self, data, periods): return super().test_diff(data, periods) + def test_insert(self, data, request): + if data.dtype.numpy_dtype == object: + mark = pytest.mark.xfail(reason="Dimension mismatch in np.concatenate") + request.node.add_marker(mark) + + super().test_insert(data) + + @skip_nested + def test_insert_invalid(self, data, invalid_scalar): + # PandasArray[object] can hold anything, so skip + super().test_insert_invalid(data, invalid_scalar) + class TestArithmetics(BaseNumPyTests, base.BaseArithmeticOpsTests): divmod_exc = None diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index af86c359c4c00..06b07968f949e 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -160,6 +160,13 @@ def test_value_counts(self, all_data, dropna): def test_value_counts_with_normalize(self, data): pass + def test_insert_invalid(self, data, invalid_scalar, request): + if data.dtype.storage == "pyarrow": + mark = pytest.mark.xfail(reason="casts invalid_scalar to string") + request.node.add_marker(mark) + + super().test_insert_invalid(data, invalid_scalar) + class TestCasting(base.BaseCastingTests): pass diff --git a/pandas/tests/indexes/timedeltas/methods/test_insert.py b/pandas/tests/indexes/timedeltas/methods/test_insert.py index 809d21db805e0..3af4b6b47fa2f 100644 --- a/pandas/tests/indexes/timedeltas/methods/test_insert.py +++ b/pandas/tests/indexes/timedeltas/methods/test_insert.py @@ -136,8 +136,8 @@ def test_insert_empty(self): result = idx[:0].insert(0, td) assert result.freq == "D" - result = idx[:0].insert(1, td) - assert result.freq == "D" + with pytest.raises(IndexError, match="loc must be an integer between"): + result = idx[:0].insert(1, td) - result = idx[:0].insert(-1, td) - assert result.freq == "D" + with pytest.raises(IndexError, match="loc must be an integer between"): + result = idx[:0].insert(-1, td) diff --git a/pandas/util/_validators.py b/pandas/util/_validators.py index 7e03e3ceea11d..f8bd1ec7bc96a 100644 --- a/pandas/util/_validators.py +++ b/pandas/util/_validators.py @@ -12,7 +12,10 @@ import numpy as np -from pandas.core.dtypes.common import is_bool +from pandas.core.dtypes.common import ( + is_bool, + is_integer, +) def _check_arg_length(fname, args, max_fname_arg_count, compat_args): @@ -494,3 +497,21 @@ def validate_inclusive(inclusive: str | None) -> tuple[bool, bool]: ) return left_right_inclusive + + +def validate_insert_loc(loc: int, length: int) -> int: + """ + Check that we have an integer between -length and length, inclusive. + + Standardize negative loc to within [0, length]. + + The exceptions we raise on failure match np.insert. + """ + if not is_integer(loc): + raise TypeError(f"loc must be an integer between -{length} and {length}") + + if loc < 0: + loc += length + if not 0 <= loc <= length: + raise IndexError(f"loc must be an integer between -{length} and {length}") + return loc