diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 7d2e3746c4b94..3274725016b40 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -20,7 +20,7 @@ import numpy as np -from pandas._libs import lib, tslib +from pandas._libs import lib, missing as libmissing, tslib from pandas._libs.tslibs import ( NaT, OutOfBoundsDatetime, @@ -519,6 +519,11 @@ def maybe_promote(dtype, fill_value=np.nan): Upcasted from dtype argument if necessary. fill_value Upcasted from fill_value argument if necessary. + + Raises + ------ + ValueError + If fill_value is a non-scalar and dtype is not object. """ if not is_scalar(fill_value) and not is_object_dtype(dtype): # with object dtype there is nothing to promote, and the user can @@ -550,6 +555,9 @@ def maybe_promote(dtype, fill_value=np.nan): dtype = np.dtype(np.object_) elif is_integer(fill_value) or (is_float(fill_value) and not isna(fill_value)): dtype = np.dtype(np.object_) + elif is_valid_nat_for_dtype(fill_value, dtype): + # e.g. pd.NA, which is not accepted by Timestamp constructor + fill_value = np.datetime64("NaT", "ns") else: try: fill_value = Timestamp(fill_value).to_datetime64() @@ -563,6 +571,9 @@ def maybe_promote(dtype, fill_value=np.nan): ): # TODO: What about str that can be a timedelta? dtype = np.dtype(np.object_) + elif is_valid_nat_for_dtype(fill_value, dtype): + # e.g pd.NA, which is not accepted by the Timedelta constructor + fill_value = np.timedelta64("NaT", "ns") else: try: fv = Timedelta(fill_value) @@ -636,7 +647,7 @@ def maybe_promote(dtype, fill_value=np.nan): # e.g. mst is np.complex128 and dtype is np.complex64 dtype = mst - elif fill_value is None: + elif fill_value is None or fill_value is libmissing.NA: if is_float_dtype(dtype) or is_complex_dtype(dtype): fill_value = np.nan elif is_integer_dtype(dtype): @@ -646,7 +657,8 @@ def maybe_promote(dtype, fill_value=np.nan): fill_value = dtype.type("NaT", "ns") else: dtype = np.dtype(np.object_) - fill_value = np.nan + if fill_value is not libmissing.NA: + fill_value = np.nan else: dtype = np.dtype(np.object_) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index e83bc9c1448eb..2261a6a20e58f 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -35,6 +35,7 @@ from pandas.core.dtypes.cast import ( find_common_type, maybe_cast_to_integer_array, + maybe_promote, validate_numeric_casting, ) from pandas.core.dtypes.common import ( @@ -4196,28 +4197,15 @@ def _string_data_error(cls, data): "to explicitly cast to a numeric type" ) - @final - def _coerce_scalar_to_index(self, item): - """ - We need to coerce a scalar to a compat for our index type. - - Parameters - ---------- - item : scalar item to coerce - """ - dtype = self.dtype - - if self._is_numeric_dtype and isna(item): - # We can't coerce to the numeric dtype of "self" (unless - # it's float) if there are NaN values in our output. - dtype = None - - return Index([item], dtype=dtype, **self._get_attributes_dict()) - def _validate_fill_value(self, value): """ - Check if the value can be inserted into our array, and convert - it to an appropriate native type if necessary. + Check if the value can be inserted into our array without casting, + and convert it to an appropriate native type if necessary. + + Raises + ------ + TypeError + If the value cannot be inserted into an array of this dtype. """ return value @@ -5583,8 +5571,22 @@ def insert(self, loc: int, item): """ # Note: this method is overridden by all ExtensionIndex subclasses, # so self is never backed by an EA. + + try: + item = self._validate_fill_value(item) + except TypeError: + if is_scalar(item): + dtype, item = maybe_promote(self.dtype, item) + else: + # maybe_promote would raise ValueError + dtype = np.dtype(object) + + return self.astype(dtype).insert(loc, item) + arr = np.asarray(self) - item = self._coerce_scalar_to_index(item)._values + + # Use Index constructor to ensure we get tuples cast correctly. + item = Index([item], dtype=self.dtype)._values idx = np.concatenate((arr[:loc], item, arr[loc:])) return Index(idx, name=self.name) diff --git a/pandas/core/indexes/numeric.py b/pandas/core/indexes/numeric.py index ed76e26a57634..117200ee53116 100644 --- a/pandas/core/indexes/numeric.py +++ b/pandas/core/indexes/numeric.py @@ -16,6 +16,7 @@ is_float, is_float_dtype, is_integer_dtype, + is_number, is_numeric_dtype, is_scalar, is_signed_integer_dtype, @@ -112,23 +113,36 @@ def _shallow_copy(self, values=None, name: Label = lib.no_default): return Float64Index._simple_new(values, name=name) return super()._shallow_copy(values=values, name=name) + @doc(Index._validate_fill_value) def _validate_fill_value(self, value): - """ - Convert value to be insertable to ndarray. - """ if is_bool(value) or is_bool_dtype(value): # force conversion to object # so we don't lose the bools raise TypeError - elif isinstance(value, str) or lib.is_complex(value): - raise TypeError elif is_scalar(value) and isna(value): if is_valid_nat_for_dtype(value, self.dtype): value = self._na_value + if self.dtype.kind != "f": + # raise so that caller can cast + raise TypeError else: # NaT, np.datetime64("NaT"), np.timedelta64("NaT") raise TypeError + elif is_scalar(value): + if not is_number(value): + # e.g. datetime64, timedelta64, datetime, ... + raise TypeError + + elif lib.is_complex(value): + # at least until we have a ComplexIndx + raise TypeError + + elif is_float(value) and self.dtype.kind != "f": + if not value.is_integer(): + raise TypeError + value = int(value) + return value def _convert_tolerance(self, tolerance, target): @@ -168,15 +182,6 @@ def _is_all_dates(self) -> bool: """ return False - @doc(Index.insert) - def insert(self, loc: int, item): - try: - item = self._validate_fill_value(item) - except TypeError: - return self.astype(object).insert(loc, item) - - return super().insert(loc, item) - def _union(self, other, sort): # Right now, we treat union(int, float) a bit special. # See https://github.com/pandas-dev/pandas/issues/26778 for discussion diff --git a/pandas/tests/dtypes/cast/test_promote.py b/pandas/tests/dtypes/cast/test_promote.py index 74a11c9f33195..294abafa86812 100644 --- a/pandas/tests/dtypes/cast/test_promote.py +++ b/pandas/tests/dtypes/cast/test_promote.py @@ -110,6 +110,8 @@ def _assert_match(result_fill_value, expected_fill_value): assert res_type == ex_type or res_type.__name__ == ex_type.__name__ match_value = result_fill_value == expected_fill_value + if match_value is pd.NA: + match_value = False # Note: type check above ensures that we have the _same_ NA value # for missing values, None == None (which is checked @@ -569,8 +571,8 @@ def test_maybe_promote_any_with_object(any_numpy_dtype_reduced, object_dtype): _check_promote(dtype, fill_value, expected_dtype, exp_val_for_scalar) -@pytest.mark.parametrize("fill_value", [None, np.nan, NaT]) -def test_maybe_promote_any_numpy_dtype_with_na(any_numpy_dtype_reduced, fill_value): +def test_maybe_promote_any_numpy_dtype_with_na(any_numpy_dtype_reduced, nulls_fixture): + fill_value = nulls_fixture dtype = np.dtype(any_numpy_dtype_reduced) if is_integer_dtype(dtype) and fill_value is not NaT: @@ -597,7 +599,10 @@ def test_maybe_promote_any_numpy_dtype_with_na(any_numpy_dtype_reduced, fill_val else: # all other cases cast to object, and use np.nan as missing value expected_dtype = np.dtype(object) - exp_val_for_scalar = np.nan + if fill_value is pd.NA: + exp_val_for_scalar = pd.NA + else: + exp_val_for_scalar = np.nan _check_promote(dtype, fill_value, expected_dtype, exp_val_for_scalar)