Skip to content

WIP: dtype._is_valid_na_for_dtype #51378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pandas/core/arrays/arrow/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from pandas._libs import missing as libmissing
from pandas._typing import (
TYPE_CHECKING,
DtypeObj,
Expand Down Expand Up @@ -209,6 +210,9 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
except NotImplementedError:
return None

def _is_valid_na_for_dtype(self, value) -> bool:
return value is None or value is libmissing.NA

def __from_arrow__(self, array: pa.Array | pa.ChunkedArray):
"""
Construct IntegerArray/FloatingArray from pyarrow Array/ChunkedArray.
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def __contains__(self, item: object) -> bool | np.bool_:
if is_scalar(item) and isna(item):
if not self._can_hold_na:
return False
elif item is self.dtype.na_value or isinstance(item, self.dtype.type):
elif self.dtype._is_valid_na_for_dtype(item):
return self._hasna
else:
return False
Expand Down
24 changes: 23 additions & 1 deletion pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ class ExtensionDtype:
* _is_numeric
* _is_boolean
* _get_common_dtype
* _is_valid_na_for_dtype

The `na_value` class attribute can be used to set the default NA value
for this type. :attr:`numpy.nan` is used by default.
for this type. :attr:`numpy.nan` is used by default. What other NA values
are accepted can be fine-tuned by overriding ``_is_valid_na_for_dtype``.

ExtensionDtypes are required to be hashable. The base class provides
a default implementation, which relies on the ``_metadata`` class
Expand Down Expand Up @@ -390,6 +392,26 @@ def _can_hold_na(self) -> bool:
"""
return True

def _is_valid_na_for_dtype(self, value) -> bool:
"""
Should we treat this value as interchangeable with self.na_value?

Use cases include:

- series[key] = value
- series.replace(other, value)
- value in index

If ``value`` is a valid NA for this dtype, ``self.na_value`` will be used
in its place for the purposes of these operations.

Notes
-----
The base class implementation considers any scalar recognized by ``pd.isna``
to be equivalent to ``self.na_value``.
"""
return libmissing.checknull(value)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could require that pd.NA always be recognized?



class StorageExtensionDtype(ExtensionDtype):
"""ExtensionDtype that may be backed by more than one implementation."""
Expand Down
24 changes: 23 additions & 1 deletion pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from __future__ import annotations

from decimal import Decimal
import re
from typing import (
TYPE_CHECKING,
Expand All @@ -14,7 +15,10 @@
import numpy as np
import pytz

from pandas._libs import missing as libmissing
from pandas._libs import (
lib,
missing as libmissing,
)
from pandas._libs.interval import Interval
from pandas._libs.properties import cache_readonly
from pandas._libs.tslibs import (
Expand Down Expand Up @@ -629,6 +633,11 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:

return find_common_type(non_cat_dtypes)

def _is_valid_na_for_dtype(self, value) -> bool:
from pandas.core.dtypes.missing import is_valid_na_for_dtype

return is_valid_na_for_dtype(value, self.categories.dtype)


@register_extension_dtype
class DatetimeTZDtype(PandasExtensionDtype):
Expand Down Expand Up @@ -819,6 +828,10 @@ def __setstate__(self, state) -> None:
self._tz = state["tz"]
self._unit = state["unit"]

def _is_valid_na_for_dtype(self, value) -> bool:
# we have to rule out tznaive dt64("NaT")
return not isinstance(value, (np.timedelta64, np.datetime64, Decimal))


@register_extension_dtype
class PeriodDtype(PeriodDtypeBase, PandasExtensionDtype):
Expand Down Expand Up @@ -1036,6 +1049,9 @@ def __from_arrow__(
return PeriodArray(np.array([], dtype="int64"), freq=self.freq, copy=False)
return PeriodArray._concat_same_type(results)

def _is_valid_na_for_dtype(self, value) -> bool:
return not isinstance(value, (np.datetime64, np.timedelta64, Decimal))


@register_extension_dtype
class IntervalDtype(PandasExtensionDtype):
Expand Down Expand Up @@ -1304,6 +1320,9 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
return np.dtype(object)
return IntervalDtype(common, closed=closed)

def _is_valid_na_for_dtype(self, value) -> bool:
return lib.is_float(value) or value is None or value is libmissing.NA


class PandasDtype(ExtensionDtype):
"""
Expand Down Expand Up @@ -1479,3 +1498,6 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
return type(self).from_numpy_dtype(new_dtype)
except (KeyError, NotImplementedError):
return None

def _is_valid_na_for_dtype(self, value) -> bool:
return value is None or value is libmissing.NA
55 changes: 23 additions & 32 deletions pandas/core/dtypes/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,8 @@
needs_i8_conversion,
)
from pandas.core.dtypes.dtypes import (
CategoricalDtype,
DatetimeTZDtype,
ExtensionDtype,
IntervalDtype,
PeriodDtype,
)
from pandas.core.dtypes.generic import (
Expand Down Expand Up @@ -692,40 +690,33 @@ def is_valid_na_for_dtype(obj, dtype: DtypeObj) -> bool:
"""
if not lib.is_scalar(obj) or not isna(obj):
return False
elif dtype.kind == "M":
if isinstance(dtype, np.dtype):

elif isinstance(dtype, np.dtype):
if dtype.kind == "M":
# i.e. not tzaware
return not isinstance(obj, (np.timedelta64, Decimal))
# we have to rule out tznaive dt64("NaT")
return not isinstance(obj, (np.timedelta64, np.datetime64, Decimal))
elif dtype.kind == "m":
return not isinstance(obj, (np.datetime64, Decimal))
elif dtype.kind in ["i", "u", "f", "c"]:
# Numeric
return obj is not NaT and not isinstance(obj, (np.datetime64, np.timedelta64))
elif dtype.kind == "b":
# We allow pd.NA, None, np.nan in BooleanArray (same as IntervalDtype)
return lib.is_float(obj) or obj is None or obj is libmissing.NA

elif dtype == _dtype_str:
# numpy string dtypes to avoid float np.nan
return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal, float))

elif dtype == _dtype_object:
# This is needed for Categorical, but is kind of weird
return True

elif isinstance(dtype, PeriodDtype):
return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal))

elif isinstance(dtype, IntervalDtype):
return lib.is_float(obj) or obj is None or obj is libmissing.NA
elif dtype.kind == "m":
return not isinstance(obj, (np.datetime64, Decimal))
elif dtype.kind in ["i", "u", "f", "c"]:
# Numeric
return obj is not NaT and not isinstance(
obj, (np.datetime64, np.timedelta64)
)
elif dtype == _dtype_str:
# numpy string dtypes to avoid float np.nan
return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal, float))
elif dtype == _dtype_object:
# This is needed for Categorical, but is kind of weird
return True
elif dtype.kind == "b":
# We allow pd.NA, None, np.nan in BooleanArray (same as IntervalDtype)
return lib.is_float(obj) or obj is None or obj is libmissing.NA

elif isinstance(dtype, CategoricalDtype):
return is_valid_na_for_dtype(obj, dtype.categories.dtype)
# fallback, default to allowing NaN, None, NA, NaT
return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal))

# fallback, default to allowing NaN, None, NA, NaT
return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal))
else:
return dtype._is_valid_na_for_dtype(obj)


def isna_all(arr: ArrayLike) -> bool:
Expand Down
7 changes: 5 additions & 2 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
using_copy_on_write,
)

from pandas._libs import lib
from pandas._libs import (
lib,
missing as libmissing,
)
from pandas._libs.lib import is_range_indexer
from pandas._libs.tslibs import (
Period,
Expand Down Expand Up @@ -8012,7 +8015,7 @@ def _clip_with_scalar(self, lower, upper, inplace: bool_t = False):
result = result.where(subset, lower, axis=None, inplace=False)

if np.any(mask):
result[mask] = np.nan
result[mask] = libmissing.NA

if inplace:
return self._update_inplace(result)
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/extension/test_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
expected = self._combine(s, other, op)

if op_name in ("__rtruediv__", "__truediv__", "__div__"):
expected = expected.fillna(np.nan).astype("Float64")
expected = expected.fillna(pd.NA).astype("Float64")
else:
# the combine method results in the 'biggest' (int64) dtype
expected = expected.astype(sdtype)
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/series/methods/test_convert_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_convert_dtypes(
# Test that it is a copy
copy = series.copy(deep=True)

result[result.notna()] = np.nan
result[result.notna()] = None

# Make sure original not changed
tm.assert_series_equal(series, copy)
Expand Down