Skip to content

Commit dba9ddd

Browse files
authored
API: ArrowDtype.type (#51307)
* API: ArrowDtype.type * handle Decimal, update docstring, update tests * patch __contains__ * mypy fixup * mypy fixup * mypy fixup * use pc.any(pc.is_nan...
1 parent e1ce580 commit dba9ddd

File tree

4 files changed

+60
-39
lines changed

4 files changed

+60
-39
lines changed

pandas/core/arrays/arrow/array.py

+12
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,18 @@ def __len__(self) -> int:
510510
"""
511511
return len(self._data)
512512

513+
def __contains__(self, key) -> bool:
514+
# https://github.com/pandas-dev/pandas/pull/51307#issuecomment-1426372604
515+
if isna(key) and key is not self.dtype.na_value:
516+
if self.dtype.kind == "f" and lib.is_float(key) and isna(key):
517+
return pc.any(pc.is_nan(self._data)).as_py()
518+
519+
# e.g. date or timestamp types we do not allow None here to match pd.NA
520+
return False
521+
# TODO: maybe complex? object?
522+
523+
return bool(super().__contains__(key))
524+
513525
@property
514526
def _hasna(self) -> bool:
515527
return self._data.null_count > 0

pandas/core/arrays/arrow/dtype.py

+44-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
11
from __future__ import annotations
22

3+
from datetime import (
4+
date,
5+
datetime,
6+
time,
7+
timedelta,
8+
)
9+
from decimal import Decimal
310
import re
411

512
import numpy as np
613

14+
from pandas._libs.tslibs import (
15+
Timedelta,
16+
Timestamp,
17+
)
718
from pandas._typing import (
819
TYPE_CHECKING,
920
DtypeObj,
@@ -88,9 +99,40 @@ def __repr__(self) -> str:
8899
@property
89100
def type(self):
90101
"""
91-
Returns pyarrow.DataType.
102+
Returns associated scalar type.
92103
"""
93-
return type(self.pyarrow_dtype)
104+
pa_type = self.pyarrow_dtype
105+
if pa.types.is_integer(pa_type):
106+
return int
107+
elif pa.types.is_floating(pa_type):
108+
return float
109+
elif pa.types.is_string(pa_type):
110+
return str
111+
elif pa.types.is_binary(pa_type):
112+
return bytes
113+
elif pa.types.is_boolean(pa_type):
114+
return bool
115+
elif pa.types.is_duration(pa_type):
116+
if pa_type.unit == "ns":
117+
return Timedelta
118+
else:
119+
return timedelta
120+
elif pa.types.is_timestamp(pa_type):
121+
if pa_type.unit == "ns":
122+
return Timestamp
123+
else:
124+
return datetime
125+
elif pa.types.is_date(pa_type):
126+
return date
127+
elif pa.types.is_time(pa_type):
128+
return time
129+
elif pa.types.is_decimal(pa_type):
130+
return Decimal
131+
elif pa.types.is_null(pa_type):
132+
# TODO: None? pd.NA? pa.null?
133+
return type(pa_type)
134+
else:
135+
raise NotImplementedError(pa_type)
94136

95137
@property
96138
def name(self) -> str: # type: ignore[override]

pandas/core/internals/blocks.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import numpy as np
1616

1717
from pandas._libs import (
18-
Timestamp,
1918
internals as libinternals,
2019
lib,
2120
writers,
@@ -63,6 +62,7 @@
6362
is_string_dtype,
6463
)
6564
from pandas.core.dtypes.dtypes import (
65+
DatetimeTZDtype,
6666
ExtensionDtype,
6767
PandasDtype,
6868
PeriodDtype,
@@ -2236,17 +2236,16 @@ def get_block_type(dtype: DtypeObj):
22362236
-------
22372237
cls : class, subclass of Block
22382238
"""
2239-
# We use vtype and kind checks because they are much more performant
2239+
# We use kind checks because it is much more performant
22402240
# than is_foo_dtype
2241-
vtype = dtype.type
22422241
kind = dtype.kind
22432242

22442243
cls: type[Block]
22452244

22462245
if isinstance(dtype, SparseDtype):
22472246
# Need this first(ish) so that Sparse[datetime] is sparse
22482247
cls = ExtensionBlock
2249-
elif vtype is Timestamp:
2248+
elif isinstance(dtype, DatetimeTZDtype):
22502249
cls = DatetimeTZBlock
22512250
elif isinstance(dtype, PeriodDtype):
22522251
cls = NDArrayBackedExtensionBlock

pandas/tests/extension/test_arrow.py

+1-33
Original file line numberDiff line numberDiff line change
@@ -324,39 +324,7 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
324324

325325

326326
class TestGetitemTests(base.BaseGetitemTests):
327-
def test_getitem_scalar(self, data):
328-
# In the base class we expect data.dtype.type; but this (intentionally)
329-
# returns Python scalars or pd.NA
330-
pa_type = data._data.type
331-
if pa.types.is_integer(pa_type):
332-
exp_type = int
333-
elif pa.types.is_floating(pa_type):
334-
exp_type = float
335-
elif pa.types.is_string(pa_type):
336-
exp_type = str
337-
elif pa.types.is_binary(pa_type):
338-
exp_type = bytes
339-
elif pa.types.is_boolean(pa_type):
340-
exp_type = bool
341-
elif pa.types.is_duration(pa_type):
342-
exp_type = timedelta
343-
elif pa.types.is_timestamp(pa_type):
344-
if pa_type.unit == "ns":
345-
exp_type = pd.Timestamp
346-
else:
347-
exp_type = datetime
348-
elif pa.types.is_date(pa_type):
349-
exp_type = date
350-
elif pa.types.is_time(pa_type):
351-
exp_type = time
352-
else:
353-
raise NotImplementedError(data.dtype)
354-
355-
result = data[0]
356-
assert isinstance(result, exp_type), type(result)
357-
358-
result = pd.Series(data)[0]
359-
assert isinstance(result, exp_type), type(result)
327+
pass
360328

361329

362330
class TestBaseAccumulateTests(base.BaseAccumulateTests):

0 commit comments

Comments
 (0)