diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py
index 3a3f0b8ce61be..18caee3249d32 100644
--- a/pandas/core/arrays/arrow/array.py
+++ b/pandas/core/arrays/arrow/array.py
@@ -506,6 +506,28 @@ def __len__(self) -> int:
         """
         return len(self._data)
 
+    def __contains__(self, key) -> bool:
+        """
+        Return whether ``key`` is present in the array.
+
+        A null-like ``key`` other than this dtype's ``na_value`` is only
+        matched as a float NaN against a float-typed array; any other key
+        is delegated to the parent implementation.
+        """
+        # https://github.com/pandas-dev/pandas/pull/51307#issuecomment-1426372604
+        if isna(key) and key is not self.dtype.na_value:
+            # isna(key) already holds in this branch, so it is not re-checked
+            if self.dtype.kind == "f" and lib.is_float(key):
+                # pc.any returns a null scalar (as_py() -> None) when there are
+                # no non-null values to aggregate; coerce that to False
+                return bool(pc.any(pc.is_nan(self._data)).as_py())
+
+            # e.g. date or timestamp types we do not allow None here to match pd.NA
+            # TODO: maybe complex? object?
+            return False
+
+        return bool(super().__contains__(key))
+
     @property
     def _hasna(self) -> bool:
         return self._data.null_count > 0
diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py
index 90c86cd6d55ef..331e66698cc35 100644
--- a/pandas/core/arrays/arrow/dtype.py
+++ b/pandas/core/arrays/arrow/dtype.py
@@ -1,9 +1,20 @@
 from __future__ import annotations
 
+from datetime import (
+    date,
+    datetime,
+    time,
+    timedelta,
+)
+from decimal import Decimal
 import re
 
 import numpy as np
 
+from pandas._libs.tslibs import (
+    Timedelta,
+    Timestamp,
+)
 from pandas._typing import (
     TYPE_CHECKING,
     DtypeObj,
@@ -88,9 +99,50 @@ def __repr__(self) -> str:
     @property
     def type(self):
         """
-        Returns pyarrow.DataType.
+        Returns associated scalar type.
+
+        Nanosecond-unit durations and timestamps map to ``Timedelta`` and
+        ``Timestamp``; other units map to ``timedelta`` and ``datetime``.
+        The null type has no scalar class yet and returns the class of the
+        pyarrow type itself.
+
+        Raises
+        ------
+        NotImplementedError
+            If the pyarrow type has no scalar type mapped here.
         """
-        return type(self.pyarrow_dtype)
+        pa_type = self.pyarrow_dtype
+        if pa.types.is_integer(pa_type):
+            return int
+        elif pa.types.is_floating(pa_type):
+            return float
+        elif pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type):
+            return str
+        elif pa.types.is_binary(pa_type) or pa.types.is_large_binary(pa_type):
+            return bytes
+        elif pa.types.is_boolean(pa_type):
+            return bool
+        elif pa.types.is_duration(pa_type):
+            if pa_type.unit == "ns":
+                return Timedelta
+            else:
+                return timedelta
+        elif pa.types.is_timestamp(pa_type):
+            if pa_type.unit == "ns":
+                return Timestamp
+            else:
+                return datetime
+        elif pa.types.is_date(pa_type):
+            return date
+        elif pa.types.is_time(pa_type):
+            return time
+        elif pa.types.is_decimal(pa_type):
+            return Decimal
+        elif pa.types.is_null(pa_type):
+            # TODO: None? pd.NA? pa.null?
+            return type(pa_type)
+        else:
+            raise NotImplementedError(pa_type)
 
     @property
     def name(self) -> str:  # type: ignore[override]
diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py
index 60d022a0c7964..55afa95129e83 100644
--- a/pandas/core/internals/blocks.py
+++ b/pandas/core/internals/blocks.py
@@ -15,7 +15,6 @@
 import numpy as np
 
 from pandas._libs import (
-    Timestamp,
     internals as libinternals,
     lib,
     writers,
@@ -63,6 +62,7 @@
     is_string_dtype,
 )
 from pandas.core.dtypes.dtypes import (
+    DatetimeTZDtype,
     ExtensionDtype,
     PandasDtype,
     PeriodDtype,
@@ -2232,9 +2232,8 @@ def get_block_type(dtype: DtypeObj):
     -------
     cls : class, subclass of Block
     """
-    # We use vtype and kind checks because they are much more performant
+    # We use kind checks because they are much more performant
     # than is_foo_dtype
-    vtype = dtype.type
     kind = dtype.kind
 
     cls: type[Block]
@@ -2242,7 +2241,7 @@ def get_block_type(dtype: DtypeObj):
     if isinstance(dtype, SparseDtype):
         # Need this first(ish) so that Sparse[datetime] is sparse
         cls = ExtensionBlock
-    elif vtype is Timestamp:
+    elif isinstance(dtype, DatetimeTZDtype):
         cls = DatetimeTZBlock
     elif isinstance(dtype, PeriodDtype):
         cls = NDArrayBackedExtensionBlock
diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py
index 681d048f38485..27b30fd6aab6f 100644
--- a/pandas/tests/extension/test_arrow.py
+++ b/pandas/tests/extension/test_arrow.py
@@ -323,39 +323,7 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
 
 
 class TestGetitemTests(base.BaseGetitemTests):
-    def test_getitem_scalar(self, data):
-        # In the base class we expect data.dtype.type; but this (intentionally)
-        # returns Python scalars or pd.NA
-        pa_type = data._data.type
-        if pa.types.is_integer(pa_type):
-            exp_type = int
-        elif pa.types.is_floating(pa_type):
-            exp_type = float
-        elif pa.types.is_string(pa_type):
-            exp_type = str
-        elif pa.types.is_binary(pa_type):
-            exp_type = bytes
-        elif pa.types.is_boolean(pa_type):
-            exp_type = bool
-        elif pa.types.is_duration(pa_type):
-            exp_type = timedelta
-        elif pa.types.is_timestamp(pa_type):
-            if pa_type.unit == "ns":
-                exp_type = pd.Timestamp
-            else:
-                exp_type = datetime
-        elif pa.types.is_date(pa_type):
-            exp_type = date
-        elif pa.types.is_time(pa_type):
-            exp_type = time
-        else:
-            raise NotImplementedError(data.dtype)
-
-        result = data[0]
-        assert isinstance(result, exp_type), type(result)
-
-        result = pd.Series(data)[0]
-        assert isinstance(result, exp_type), type(result)
+    pass
 
 
 class TestBaseAccumulateTests(base.BaseAccumulateTests):