Skip to content

Commit 6821872

Browse files
authored
BUG: NumericIndex.insert (#43933)
1 parent f4e7fdc commit 6821872

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

pandas/core/indexes/base.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -6329,10 +6329,11 @@ def insert(self, loc: int, item) -> Index:
63296329

63306330
arr = np.asarray(self)
63316331

6332-
# Use Index constructor to ensure we get tuples cast correctly.
6333-
item = Index([item], dtype=self.dtype)._values
6332+
# Use constructor to ensure we get tuples cast correctly.
6333+
# Use self._constructor instead of Index to retain NumericIndex GH#43921
6334+
item = self._constructor([item], dtype=self.dtype)._values
63346335
idx = np.concatenate((arr[:loc], item, arr[loc:]))
6335-
return Index._with_infer(idx, name=self.name)
6336+
return self._constructor._with_infer(idx, name=self.name)
63366337

63376338
def drop(self, labels, errors: str_t = "raise") -> Index:
63386339
"""

pandas/tests/indexes/common.py

+19
Original file line numberDiff line numberDiff line change
@@ -793,13 +793,32 @@ def test_format(self, simple_index):
793793
def test_numeric_compat(self):
794794
pass # override Base method
795795

796+
def test_insert_non_na(self, simple_index):
797+
# GH#43921 inserting an element that we know we can hold should
798+
# not change dtype or type (except for RangeIndex)
799+
index = simple_index
800+
801+
result = index.insert(0, index[0])
802+
803+
cls = type(index)
804+
if cls is RangeIndex:
805+
cls = Int64Index
806+
807+
expected = cls([index[0]] + list(index), dtype=index.dtype)
808+
tm.assert_index_equal(result, expected)
809+
796810
def test_insert_na(self, nulls_fixture, simple_index):
797811
# GH 18295 (test missing)
798812
index = simple_index
799813
na_val = nulls_fixture
800814

801815
if na_val is pd.NaT:
802816
expected = Index([index[0], pd.NaT] + list(index[1:]), dtype=object)
817+
elif type(index) is NumericIndex and index.dtype.kind == "f":
818+
# GH#43921
819+
expected = NumericIndex(
820+
[index[0], np.nan] + list(index[1:]), dtype=index.dtype
821+
)
803822
else:
804823
expected = Float64Index([index[0], np.nan] + list(index[1:]))
805824

0 commit comments

Comments
 (0)