diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index 90c86cd6d55ef..437cd975b6bb2 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -4,6 +4,7 @@ import numpy as np +from pandas._libs import missing as libmissing from pandas._typing import ( TYPE_CHECKING, DtypeObj, @@ -209,6 +210,9 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: except NotImplementedError: return None + def _is_valid_na_for_dtype(self, value) -> bool: + return value is None or value is libmissing.NA + def __from_arrow__(self, array: pa.Array | pa.ChunkedArray): """ Construct IntegerArray/FloatingArray from pyarrow Array/ChunkedArray. diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index c261a41e1e77e..6d0ff112ad235 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -425,7 +425,7 @@ def __contains__(self, item: object) -> bool | np.bool_: if is_scalar(item) and isna(item): if not self._can_hold_na: return False - elif item is self.dtype.na_value or isinstance(item, self.dtype.type): + elif self.dtype._is_valid_na_for_dtype(item): return self._hasna else: return False diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index bce2a82f057f3..b8f117b8ef547 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -61,9 +61,11 @@ class ExtensionDtype: * _is_numeric * _is_boolean * _get_common_dtype + * _is_valid_na_for_dtype The `na_value` class attribute can be used to set the default NA value - for this type. :attr:`numpy.nan` is used by default. + for this type. :attr:`numpy.nan` is used by default. What other NA values + are accepted can be fine-tuned by overriding ``_is_valid_na_for_dtype``. ExtensionDtypes are required to be hashable. The base class provides a default implementation, which relies on the ``_metadata`` class @@ -390,6 +392,26 @@ def _can_hold_na(self) -> bool: """ return True + def _is_valid_na_for_dtype(self, value) -> bool: + """ + Should we treat this value as interchangeable with self.na_value? + + Use cases include: + + - series[key] = value + - series.replace(other, value) + - value in index + + If ``value`` is a valid na for this dtype, ``self.na_value`` will be used + in its place for the purpose of this operations. + + Notes + ----- + The base class implementation considers any scalar recognized by pd.isna + to be equivalent. + """ + return libmissing.checknull(value) + class StorageExtensionDtype(ExtensionDtype): """ExtensionDtype that may be backed by more than one implementation.""" diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 33ff6d1eee686..1d3221f23fd4a 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -3,6 +3,7 @@ """ from __future__ import annotations +from decimal import Decimal import re from typing import ( TYPE_CHECKING, @@ -14,7 +15,10 @@ import numpy as np import pytz -from pandas._libs import missing as libmissing +from pandas._libs import ( + lib, + missing as libmissing, +) from pandas._libs.interval import Interval from pandas._libs.properties import cache_readonly from pandas._libs.tslibs import ( @@ -629,6 +633,11 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: return find_common_type(non_cat_dtypes) + def _is_valid_na_for_dtype(self, value) -> bool: + from pandas.core.dtypes.missing import is_valid_na_for_dtype + + return is_valid_na_for_dtype(value, self.categories.dtype) + @register_extension_dtype class DatetimeTZDtype(PandasExtensionDtype): @@ -819,6 +828,10 @@ def __setstate__(self, state) -> None: self._tz = state["tz"] self._unit = state["unit"] + def _is_valid_na_for_dtype(self, value) -> bool: + # we have to rule out tznaive dt64("NaT") + return not isinstance(value, (np.timedelta64, np.datetime64, Decimal)) + @register_extension_dtype class PeriodDtype(PeriodDtypeBase, PandasExtensionDtype): @@ -1036,6 +1049,9 @@ def __from_arrow__( return PeriodArray(np.array([], dtype="int64"), freq=self.freq, copy=False) return PeriodArray._concat_same_type(results) + def _is_valid_na_for_dtype(self, value) -> bool: + return not isinstance(value, (np.datetime64, np.timedelta64, Decimal)) + @register_extension_dtype class IntervalDtype(PandasExtensionDtype): @@ -1304,6 +1320,9 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: return np.dtype(object) return IntervalDtype(common, closed=closed) + def _is_valid_na_for_dtype(self, value) -> bool: + return lib.is_float(value) or value is None or value is libmissing.NA + class PandasDtype(ExtensionDtype): """ @@ -1479,3 +1498,6 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: return type(self).from_numpy_dtype(new_dtype) except (KeyError, NotImplementedError): return None + + def _is_valid_na_for_dtype(self, value) -> bool: + return value is None or value is libmissing.NA diff --git a/pandas/core/dtypes/missing.py b/pandas/core/dtypes/missing.py index 211b67d3590ed..f111bd97a00de 100644 --- a/pandas/core/dtypes/missing.py +++ b/pandas/core/dtypes/missing.py @@ -38,10 +38,8 @@ needs_i8_conversion, ) from pandas.core.dtypes.dtypes import ( - CategoricalDtype, DatetimeTZDtype, ExtensionDtype, - IntervalDtype, PeriodDtype, ) from pandas.core.dtypes.generic import ( @@ -692,40 +690,33 @@ def is_valid_na_for_dtype(obj, dtype: DtypeObj) -> bool: """ if not lib.is_scalar(obj) or not isna(obj): return False - elif dtype.kind == "M": - if isinstance(dtype, np.dtype): + + elif isinstance(dtype, np.dtype): + if dtype.kind == "M": # i.e. not tzaware return not isinstance(obj, (np.timedelta64, Decimal)) - # we have to rule out tznaive dt64("NaT") - return not isinstance(obj, (np.timedelta64, np.datetime64, Decimal)) - elif dtype.kind == "m": - return not isinstance(obj, (np.datetime64, Decimal)) - elif dtype.kind in ["i", "u", "f", "c"]: - # Numeric - return obj is not NaT and not isinstance(obj, (np.datetime64, np.timedelta64)) - elif dtype.kind == "b": - # We allow pd.NA, None, np.nan in BooleanArray (same as IntervalDtype) - return lib.is_float(obj) or obj is None or obj is libmissing.NA - - elif dtype == _dtype_str: - # numpy string dtypes to avoid float np.nan - return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal, float)) - - elif dtype == _dtype_object: - # This is needed for Categorical, but is kind of weird - return True - - elif isinstance(dtype, PeriodDtype): - return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal)) - - elif isinstance(dtype, IntervalDtype): - return lib.is_float(obj) or obj is None or obj is libmissing.NA + elif dtype.kind == "m": + return not isinstance(obj, (np.datetime64, Decimal)) + elif dtype.kind in ["i", "u", "f", "c"]: + # Numeric + return obj is not NaT and not isinstance( + obj, (np.datetime64, np.timedelta64) + ) + elif dtype == _dtype_str: + # numpy string dtypes to avoid float np.nan + return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal, float)) + elif dtype == _dtype_object: + # This is needed for Categorical, but is kind of weird + return True + elif dtype.kind == "b": + # We allow pd.NA, None, np.nan in BooleanArray (same as IntervalDtype) + return lib.is_float(obj) or obj is None or obj is libmissing.NA - elif isinstance(dtype, CategoricalDtype): - return is_valid_na_for_dtype(obj, dtype.categories.dtype) + # fallback, default to allowing NaN, None, NA, NaT + return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal)) - # fallback, default to allowing NaN, None, NA, NaT - return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal)) + else: + return dtype._is_valid_na_for_dtype(obj) def isna_all(arr: ArrayLike) -> bool: diff --git a/pandas/core/generic.py b/pandas/core/generic.py index cb321f0584294..7807f0a170bd1 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -35,7 +35,10 @@ using_copy_on_write, ) -from pandas._libs import lib +from pandas._libs import ( + lib, + missing as libmissing, +) from pandas._libs.lib import is_range_indexer from pandas._libs.tslibs import ( Period, @@ -8012,7 +8015,7 @@ def _clip_with_scalar(self, lower, upper, inplace: bool_t = False): result = result.where(subset, lower, axis=None, inplace=False) if np.any(mask): - result[mask] = np.nan + result[mask] = libmissing.NA if inplace: return self._update_inplace(result) diff --git a/pandas/tests/extension/test_integer.py b/pandas/tests/extension/test_integer.py index 936764c3627d0..65543cd6bbf1c 100644 --- a/pandas/tests/extension/test_integer.py +++ b/pandas/tests/extension/test_integer.py @@ -135,7 +135,7 @@ def _check_op(self, s, op, other, op_name, exc=NotImplementedError): expected = self._combine(s, other, op) if op_name in ("__rtruediv__", "__truediv__", "__div__"): - expected = expected.fillna(np.nan).astype("Float64") + expected = expected.fillna(pd.NA).astype("Float64") else: # combine method result in 'biggest' (int64) dtype expected = expected.astype(sdtype) diff --git a/pandas/tests/series/methods/test_convert_dtypes.py b/pandas/tests/series/methods/test_convert_dtypes.py index d91cd6a43daea..2d12ccab6f95c 100644 --- a/pandas/tests/series/methods/test_convert_dtypes.py +++ b/pandas/tests/series/methods/test_convert_dtypes.py @@ -193,7 +193,7 @@ def test_convert_dtypes( # Test that it is a copy copy = series.copy(deep=True) - result[result.notna()] = np.nan + result[result.notna()] = None # Make sure original not changed tm.assert_series_equal(series, copy)