Skip to content

Commit 8969a58

Browse files
authored
ENH: ExtensionArray.insert (#44138)
1 parent dc3c4b7 commit 8969a58

File tree

10 files changed

+139
-11
lines changed

10 files changed

+139
-11
lines changed

doc/source/reference/extensions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ objects.
4848
api.extensions.ExtensionArray.equals
4949
api.extensions.ExtensionArray.factorize
5050
api.extensions.ExtensionArray.fillna
51+
api.extensions.ExtensionArray.insert
5152
api.extensions.ExtensionArray.isin
5253
api.extensions.ExtensionArray.isna
5354
api.extensions.ExtensionArray.ravel

pandas/core/arrays/_mixins.py

+3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from pandas.util._validators import (
3232
validate_bool_kwarg,
3333
validate_fillna_kwargs,
34+
validate_insert_loc,
3435
)
3536

3637
from pandas.core.dtypes.common import is_dtype_equal
@@ -359,6 +360,8 @@ def insert(
359360
-------
360361
type(self)
361362
"""
363+
loc = validate_insert_loc(loc, len(self))
364+
362365
code = self._validate_scalar(item)
363366

364367
new_vals = np.concatenate(

pandas/core/arrays/base.py

+30
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from pandas.util._validators import (
4848
validate_bool_kwarg,
4949
validate_fillna_kwargs,
50+
validate_insert_loc,
5051
)
5152

5253
from pandas.core.dtypes.cast import maybe_cast_to_extension_array
@@ -123,6 +124,7 @@ class ExtensionArray:
123124
factorize
124125
fillna
125126
equals
127+
insert
126128
isin
127129
isna
128130
ravel
@@ -1381,6 +1383,34 @@ def delete(self: ExtensionArrayT, loc: PositionalIndexer) -> ExtensionArrayT:
13811383
indexer = np.delete(np.arange(len(self)), loc)
13821384
return self.take(indexer)
13831385

1386+
def insert(self: ExtensionArrayT, loc: int, item) -> ExtensionArrayT:
1387+
"""
1388+
Insert an item at the given position.
1389+
1390+
Parameters
1391+
----------
1392+
loc : int
1393+
item : scalar-like
1394+
1395+
Returns
1396+
-------
1397+
same type as self
1398+
1399+
Notes
1400+
-----
1401+
This method should be both type and dtype-preserving. If the item
1402+
cannot be held in an array of this type/dtype, either ValueError or
1403+
TypeError should be raised.
1404+
1405+
The default implementation relies on _from_sequence to raise on invalid
1406+
items.
1407+
"""
1408+
loc = validate_insert_loc(loc, len(self))
1409+
1410+
item_arr = type(self)._from_sequence([item], dtype=self.dtype)
1411+
1412+
return type(self)._concat_same_type([self[:loc], item_arr, self[loc:]])
1413+
13841414
@classmethod
13851415
def _empty(cls, shape: Shape, dtype: ExtensionDtype):
13861416
"""

pandas/tests/arithmetic/test_interval.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,16 @@
2828
(Index([0, 2, 4, 4]), Index([1, 3, 5, 8])),
2929
(Index([0.0, 1.0, 2.0, np.nan]), Index([1.0, 2.0, 3.0, np.nan])),
3030
(
31-
timedelta_range("0 days", periods=3).insert(4, pd.NaT),
32-
timedelta_range("1 day", periods=3).insert(4, pd.NaT),
31+
timedelta_range("0 days", periods=3).insert(3, pd.NaT),
32+
timedelta_range("1 day", periods=3).insert(3, pd.NaT),
3333
),
3434
(
35-
date_range("20170101", periods=3).insert(4, pd.NaT),
36-
date_range("20170102", periods=3).insert(4, pd.NaT),
35+
date_range("20170101", periods=3).insert(3, pd.NaT),
36+
date_range("20170102", periods=3).insert(3, pd.NaT),
3737
),
3838
(
39-
date_range("20170101", periods=3, tz="US/Eastern").insert(4, pd.NaT),
40-
date_range("20170102", periods=3, tz="US/Eastern").insert(4, pd.NaT),
39+
date_range("20170101", periods=3, tz="US/Eastern").insert(3, pd.NaT),
40+
date_range("20170102", periods=3, tz="US/Eastern").insert(3, pd.NaT),
4141
),
4242
],
4343
ids=lambda x: str(x[0].dtype),

pandas/tests/extension/base/methods.py

+42
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,48 @@ def test_delete(self, data):
511511
expected = data._concat_same_type([data[[0]], data[[2]], data[4:]])
512512
self.assert_extension_array_equal(result, expected)
513513

514+
def test_insert(self, data):
515+
# insert at the beginning
516+
result = data[1:].insert(0, data[0])
517+
self.assert_extension_array_equal(result, data)
518+
519+
result = data[1:].insert(-len(data[1:]), data[0])
520+
self.assert_extension_array_equal(result, data)
521+
522+
# insert at the middle
523+
result = data[:-1].insert(4, data[-1])
524+
525+
taker = np.arange(len(data))
526+
taker[5:] = taker[4:-1]
527+
taker[4] = len(data) - 1
528+
expected = data.take(taker)
529+
self.assert_extension_array_equal(result, expected)
530+
531+
def test_insert_invalid(self, data, invalid_scalar):
532+
item = invalid_scalar
533+
534+
with pytest.raises((TypeError, ValueError)):
535+
data.insert(0, item)
536+
537+
with pytest.raises((TypeError, ValueError)):
538+
data.insert(4, item)
539+
540+
with pytest.raises((TypeError, ValueError)):
541+
data.insert(len(data) - 1, item)
542+
543+
def test_insert_invalid_loc(self, data):
544+
ub = len(data)
545+
546+
with pytest.raises(IndexError):
547+
data.insert(ub + 1, data[0])
548+
549+
with pytest.raises(IndexError):
550+
data.insert(-ub - 1, data[0])
551+
552+
with pytest.raises(TypeError):
553+
# we expect TypeError here instead of IndexError to match np.insert
554+
data.insert(1.5, data[0])
555+
514556
@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
515557
def test_equals(self, data, na_value, as_series, box):
516558
data2 = type(data)._from_sequence([data[0]] * len(data), dtype=data.dtype)

pandas/tests/extension/conftest.py

+12
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,15 @@ def as_array(request):
181181
Boolean fixture to support ExtensionDtype _from_sequence method testing.
182182
"""
183183
return request.param
184+
185+
186+
@pytest.fixture
187+
def invalid_scalar(data):
188+
"""
189+
A scalar that *cannot* be held by this ExtensionArray.
190+
191+
The default should work for most subclasses, but is not guaranteed.
192+
193+
If the array can hold any item (i.e. object dtype), then use pytest.skip.
194+
"""
195+
return object.__new__(object)

pandas/tests/extension/test_numpy.py

+12
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,18 @@ def test_searchsorted(self, data_for_sorting, as_series):
265265
def test_diff(self, data, periods):
266266
return super().test_diff(data, periods)
267267

268+
def test_insert(self, data, request):
269+
if data.dtype.numpy_dtype == object:
270+
mark = pytest.mark.xfail(reason="Dimension mismatch in np.concatenate")
271+
request.node.add_marker(mark)
272+
273+
super().test_insert(data)
274+
275+
@skip_nested
276+
def test_insert_invalid(self, data, invalid_scalar):
277+
# PandasArray[object] can hold anything, so skip
278+
super().test_insert_invalid(data, invalid_scalar)
279+
268280

269281
class TestArithmetics(BaseNumPyTests, base.BaseArithmeticOpsTests):
270282
divmod_exc = None

pandas/tests/extension/test_string.py

+7
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,13 @@ def test_value_counts(self, all_data, dropna):
160160
def test_value_counts_with_normalize(self, data):
161161
pass
162162

163+
def test_insert_invalid(self, data, invalid_scalar, request):
164+
if data.dtype.storage == "pyarrow":
165+
mark = pytest.mark.xfail(reason="casts invalid_scalar to string")
166+
request.node.add_marker(mark)
167+
168+
super().test_insert_invalid(data, invalid_scalar)
169+
163170

164171
class TestCasting(base.BaseCastingTests):
165172
pass

pandas/tests/indexes/timedeltas/methods/test_insert.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ def test_insert_empty(self):
139139
result = idx[:0].insert(0, td)
140140
assert result.freq == "D"
141141

142-
result = idx[:0].insert(1, td)
143-
assert result.freq == "D"
142+
with pytest.raises(IndexError, match="loc must be an integer between"):
143+
result = idx[:0].insert(1, td)
144144

145-
result = idx[:0].insert(-1, td)
146-
assert result.freq == "D"
145+
with pytest.raises(IndexError, match="loc must be an integer between"):
146+
result = idx[:0].insert(-1, td)

pandas/util/_validators.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212

1313
import numpy as np
1414

15-
from pandas.core.dtypes.common import is_bool
15+
from pandas.core.dtypes.common import (
16+
is_bool,
17+
is_integer,
18+
)
1619

1720

1821
def _check_arg_length(fname, args, max_fname_arg_count, compat_args):
@@ -494,3 +497,21 @@ def validate_inclusive(inclusive: str | None) -> tuple[bool, bool]:
494497
)
495498

496499
return left_right_inclusive
500+
501+
502+
def validate_insert_loc(loc: int, length: int) -> int:
503+
"""
504+
Check that we have an integer between -length and length, inclusive.
505+
506+
Standardize negative loc to within [0, length].
507+
508+
The exceptions we raise on failure match np.insert.
509+
"""
510+
if not is_integer(loc):
511+
raise TypeError(f"loc must be an integer between -{length} and {length}")
512+
513+
if loc < 0:
514+
loc += length
515+
if not 0 <= loc <= length:
516+
raise IndexError(f"loc must be an integer between -{length} and {length}")
517+
return loc

0 commit comments

Comments
 (0)