Skip to content

Commit 67305b2

Browse files
authored
REF: use _validate_fill_value in Index.insert (#38102)
1 parent 8f81d0e commit 67305b2

File tree

4 files changed

+65
-41
lines changed

4 files changed

+65
-41
lines changed

pandas/core/dtypes/cast.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import numpy as np
2222

23-
from pandas._libs import lib, tslib
23+
from pandas._libs import lib, missing as libmissing, tslib
2424
from pandas._libs.tslibs import (
2525
NaT,
2626
OutOfBoundsDatetime,
@@ -519,6 +519,11 @@ def maybe_promote(dtype, fill_value=np.nan):
519519
Upcasted from dtype argument if necessary.
520520
fill_value
521521
Upcasted from fill_value argument if necessary.
522+
523+
Raises
524+
------
525+
ValueError
526+
If fill_value is a non-scalar and dtype is not object.
522527
"""
523528
if not is_scalar(fill_value) and not is_object_dtype(dtype):
524529
# with object dtype there is nothing to promote, and the user can
@@ -550,6 +555,9 @@ def maybe_promote(dtype, fill_value=np.nan):
550555
dtype = np.dtype(np.object_)
551556
elif is_integer(fill_value) or (is_float(fill_value) and not isna(fill_value)):
552557
dtype = np.dtype(np.object_)
558+
elif is_valid_nat_for_dtype(fill_value, dtype):
559+
# e.g. pd.NA, which is not accepted by Timestamp constructor
560+
fill_value = np.datetime64("NaT", "ns")
553561
else:
554562
try:
555563
fill_value = Timestamp(fill_value).to_datetime64()
@@ -563,6 +571,9 @@ def maybe_promote(dtype, fill_value=np.nan):
563571
):
564572
# TODO: What about str that can be a timedelta?
565573
dtype = np.dtype(np.object_)
574+
elif is_valid_nat_for_dtype(fill_value, dtype):
575+
# e.g. pd.NA, which is not accepted by the Timedelta constructor
576+
fill_value = np.timedelta64("NaT", "ns")
566577
else:
567578
try:
568579
fv = Timedelta(fill_value)
@@ -636,7 +647,7 @@ def maybe_promote(dtype, fill_value=np.nan):
636647
# e.g. mst is np.complex128 and dtype is np.complex64
637648
dtype = mst
638649

639-
elif fill_value is None:
650+
elif fill_value is None or fill_value is libmissing.NA:
640651
if is_float_dtype(dtype) or is_complex_dtype(dtype):
641652
fill_value = np.nan
642653
elif is_integer_dtype(dtype):
@@ -646,7 +657,8 @@ def maybe_promote(dtype, fill_value=np.nan):
646657
fill_value = dtype.type("NaT", "ns")
647658
else:
648659
dtype = np.dtype(np.object_)
649-
fill_value = np.nan
660+
if fill_value is not libmissing.NA:
661+
fill_value = np.nan
650662
else:
651663
dtype = np.dtype(np.object_)
652664

pandas/core/indexes/base.py

+23-21
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from pandas.core.dtypes.cast import (
3636
find_common_type,
3737
maybe_cast_to_integer_array,
38+
maybe_promote,
3839
validate_numeric_casting,
3940
)
4041
from pandas.core.dtypes.common import (
@@ -4196,28 +4197,15 @@ def _string_data_error(cls, data):
41964197
"to explicitly cast to a numeric type"
41974198
)
41984199

4199-
@final
4200-
def _coerce_scalar_to_index(self, item):
4201-
"""
4202-
We need to coerce a scalar to a compat for our index type.
4203-
4204-
Parameters
4205-
----------
4206-
item : scalar item to coerce
4207-
"""
4208-
dtype = self.dtype
4209-
4210-
if self._is_numeric_dtype and isna(item):
4211-
# We can't coerce to the numeric dtype of "self" (unless
4212-
# it's float) if there are NaN values in our output.
4213-
dtype = None
4214-
4215-
return Index([item], dtype=dtype, **self._get_attributes_dict())
4216-
42174200
def _validate_fill_value(self, value):
42184201
"""
4219-
Check if the value can be inserted into our array, and convert
4220-
it to an appropriate native type if necessary.
4202+
Check if the value can be inserted into our array without casting,
4203+
and convert it to an appropriate native type if necessary.
4204+
4205+
Raises
4206+
------
4207+
TypeError
4208+
If the value cannot be inserted into an array of this dtype.
42214209
"""
42224210
return value
42234211

@@ -5583,8 +5571,22 @@ def insert(self, loc: int, item):
55835571
"""
55845572
# Note: this method is overridden by all ExtensionIndex subclasses,
55855573
# so self is never backed by an EA.
5574+
5575+
try:
5576+
item = self._validate_fill_value(item)
5577+
except TypeError:
5578+
if is_scalar(item):
5579+
dtype, item = maybe_promote(self.dtype, item)
5580+
else:
5581+
# maybe_promote would raise ValueError
5582+
dtype = np.dtype(object)
5583+
5584+
return self.astype(dtype).insert(loc, item)
5585+
55865586
arr = np.asarray(self)
5587-
item = self._coerce_scalar_to_index(item)._values
5587+
5588+
# Use Index constructor to ensure we get tuples cast correctly.
5589+
item = Index([item], dtype=self.dtype)._values
55885590
idx = np.concatenate((arr[:loc], item, arr[loc:]))
55895591
return Index(idx, name=self.name)
55905592

pandas/core/indexes/numeric.py

+19-14
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
is_float,
1717
is_float_dtype,
1818
is_integer_dtype,
19+
is_number,
1920
is_numeric_dtype,
2021
is_scalar,
2122
is_signed_integer_dtype,
@@ -112,23 +113,36 @@ def _shallow_copy(self, values=None, name: Label = lib.no_default):
112113
return Float64Index._simple_new(values, name=name)
113114
return super()._shallow_copy(values=values, name=name)
114115

116+
@doc(Index._validate_fill_value)
115117
def _validate_fill_value(self, value):
116-
"""
117-
Convert value to be insertable to ndarray.
118-
"""
119118
if is_bool(value) or is_bool_dtype(value):
120119
# force conversion to object
121120
# so we don't lose the bools
122121
raise TypeError
123-
elif isinstance(value, str) or lib.is_complex(value):
124-
raise TypeError
125122
elif is_scalar(value) and isna(value):
126123
if is_valid_nat_for_dtype(value, self.dtype):
127124
value = self._na_value
125+
if self.dtype.kind != "f":
126+
# raise so that caller can cast
127+
raise TypeError
128128
else:
129129
# NaT, np.datetime64("NaT"), np.timedelta64("NaT")
130130
raise TypeError
131131

132+
elif is_scalar(value):
133+
if not is_number(value):
134+
# e.g. datetime64, timedelta64, datetime, ...
135+
raise TypeError
136+
137+
elif lib.is_complex(value):
138+
# at least until we have a ComplexIndex
139+
raise TypeError
140+
141+
elif is_float(value) and self.dtype.kind != "f":
142+
if not value.is_integer():
143+
raise TypeError
144+
value = int(value)
145+
132146
return value
133147

134148
def _convert_tolerance(self, tolerance, target):
@@ -168,15 +182,6 @@ def _is_all_dates(self) -> bool:
168182
"""
169183
return False
170184

171-
@doc(Index.insert)
172-
def insert(self, loc: int, item):
173-
try:
174-
item = self._validate_fill_value(item)
175-
except TypeError:
176-
return self.astype(object).insert(loc, item)
177-
178-
return super().insert(loc, item)
179-
180185
def _union(self, other, sort):
181186
# Right now, we treat union(int, float) a bit special.
182187
# See https://github.com/pandas-dev/pandas/issues/26778 for discussion

pandas/tests/dtypes/cast/test_promote.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ def _assert_match(result_fill_value, expected_fill_value):
110110
assert res_type == ex_type or res_type.__name__ == ex_type.__name__
111111

112112
match_value = result_fill_value == expected_fill_value
113+
if match_value is pd.NA:
114+
match_value = False
113115

114116
# Note: type check above ensures that we have the _same_ NA value
115117
# for missing values, None == None (which is checked
@@ -569,8 +571,8 @@ def test_maybe_promote_any_with_object(any_numpy_dtype_reduced, object_dtype):
569571
_check_promote(dtype, fill_value, expected_dtype, exp_val_for_scalar)
570572

571573

572-
@pytest.mark.parametrize("fill_value", [None, np.nan, NaT])
573-
def test_maybe_promote_any_numpy_dtype_with_na(any_numpy_dtype_reduced, fill_value):
574+
def test_maybe_promote_any_numpy_dtype_with_na(any_numpy_dtype_reduced, nulls_fixture):
575+
fill_value = nulls_fixture
574576
dtype = np.dtype(any_numpy_dtype_reduced)
575577

576578
if is_integer_dtype(dtype) and fill_value is not NaT:
@@ -597,7 +599,10 @@ def test_maybe_promote_any_numpy_dtype_with_na(any_numpy_dtype_reduced, fill_val
597599
else:
598600
# all other cases cast to object; pd.NA is kept as-is, any other null becomes np.nan
599601
expected_dtype = np.dtype(object)
600-
exp_val_for_scalar = np.nan
602+
if fill_value is pd.NA:
603+
exp_val_for_scalar = pd.NA
604+
else:
605+
exp_val_for_scalar = np.nan
601606

602607
_check_promote(dtype, fill_value, expected_dtype, exp_val_for_scalar)
603608

0 commit comments

Comments
 (0)