
Commit f30c7d7

ENH: Incorporate ArrowDtype into ArrowExtensionArray (#47034)
1 parent fcd94b3 commit f30c7d7

File tree

11 files changed: +443 -140 lines

pandas/_testing/__init__.py

+40
@@ -26,6 +26,7 @@
 )
 
 from pandas._typing import Dtype
+from pandas.compat import pa_version_under1p01
 
 from pandas.core.dtypes.common import (
     is_float_dtype,
@@ -192,6 +193,45 @@
     ]
 ]
 
+if not pa_version_under1p01:
+    import pyarrow as pa
+
+    UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
+    SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
+    ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES
+
+    FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
+    STRING_PYARROW_DTYPES = [pa.string(), pa.utf8()]
+
+    TIME_PYARROW_DTYPES = [
+        pa.time32("s"),
+        pa.time32("ms"),
+        pa.time64("us"),
+        pa.time64("ns"),
+    ]
+    DATE_PYARROW_DTYPES = [pa.date32(), pa.date64()]
+    DATETIME_PYARROW_DTYPES = [
+        pa.timestamp(unit=unit, tz=tz)
+        for unit in ["s", "ms", "us", "ns"]
+        for tz in [None, "UTC", "US/Pacific", "US/Eastern"]
+    ]
+    TIMEDELTA_PYARROW_DTYPES = [pa.duration(unit) for unit in ["s", "ms", "us", "ns"]]
+
+    BOOL_PYARROW_DTYPES = [pa.bool_()]
+
+    # TODO: Add container like pyarrow types:
+    # https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions
+    ALL_PYARROW_DTYPES = (
+        ALL_INT_PYARROW_DTYPES
+        + FLOAT_PYARROW_DTYPES
+        + TIME_PYARROW_DTYPES
+        + DATE_PYARROW_DTYPES
+        + DATETIME_PYARROW_DTYPES
+        + TIMEDELTA_PYARROW_DTYPES
+        + BOOL_PYARROW_DTYPES
+    )
+
+
 EMPTY_STRING_PATTERN = re.compile("^$")
 
 # set testing_mode
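The dtype lists above are parametrization fixtures for Arrow-backed tests. A minimal sketch of how a test module might consume them, assuming pyarrow >= 1.0.0 is installed; the test name and assertion are illustrative and not part of this commit:

import pytest

import pandas._testing as tm
from pandas.compat import pa_version_under1p01

# Skip the whole module when pyarrow is missing, since the lists above are only
# defined under `if not pa_version_under1p01`.
pytestmark = pytest.mark.skipif(pa_version_under1p01, reason="pyarrow>=1.0.0 required")


@pytest.mark.parametrize("pa_dtype", getattr(tm, "ALL_PYARROW_DTYPES", []))
def test_arrow_dtype_wraps_pyarrow_type(pa_dtype):
    # ArrowDtype wraps a pyarrow DataType (see pandas/core/arrays/arrow/dtype.py).
    from pandas.core.arrays.arrow.dtype import ArrowDtype

    dtype = ArrowDtype(pa_dtype)
    assert dtype.pyarrow_dtype == pa_dtype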

pandas/core/arrays/arrow/array.py

+132 -5
@@ -9,6 +9,8 @@
 import numpy as np
 
 from pandas._typing import (
+    Dtype,
+    PositionalIndexer,
     TakeIndexer,
     npt,
 )
@@ -24,13 +26,15 @@
     is_array_like,
     is_bool_dtype,
     is_integer,
+    is_integer_dtype,
     is_scalar,
 )
 from pandas.core.dtypes.missing import isna
 
 from pandas.core.arrays.base import ExtensionArray
 from pandas.core.indexers import (
     check_array_indexer,
+    unpack_tuple_and_ellipses,
     validate_indices,
 )
 
@@ -39,6 +43,7 @@
     import pyarrow.compute as pc
 
     from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
+    from pandas.core.arrays.arrow.dtype import ArrowDtype
 
 if TYPE_CHECKING:
     from pandas import Series
@@ -48,16 +53,130 @@
 
 class ArrowExtensionArray(ExtensionArray):
     """
-    Base class for ExtensionArray backed by Arrow array.
+    Base class for ExtensionArray backed by Arrow ChunkedArray.
     """
 
     _data: pa.ChunkedArray
 
-    def __init__(self, values: pa.ChunkedArray) -> None:
-        self._data = values
+    def __init__(self, values: pa.Array | pa.ChunkedArray) -> None:
+        if pa_version_under1p01:
+            msg = "pyarrow>=1.0.0 is required for PyArrow backed ArrowExtensionArray."
+            raise ImportError(msg)
+        if isinstance(values, pa.Array):
+            self._data = pa.chunked_array([values])
+        elif isinstance(values, pa.ChunkedArray):
+            self._data = values
+        else:
+            raise ValueError(
+                f"Unsupported type '{type(values)}' for ArrowExtensionArray"
+            )
+        self._dtype = ArrowDtype(self._data.type)
+
+    @classmethod
+    def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy=False):
+        """
+        Construct a new ExtensionArray from a sequence of scalars.
+        """
+        if isinstance(dtype, ArrowDtype):
+            pa_dtype = dtype.pyarrow_dtype
+        elif dtype:
+            pa_dtype = pa.from_numpy_dtype(dtype)
+        else:
+            pa_dtype = None
+
+        if isinstance(scalars, cls):
+            data = scalars._data
+            if pa_dtype:
+                data = data.cast(pa_dtype)
+            return cls(data)
+        else:
+            return cls(
+                pa.chunked_array(pa.array(scalars, type=pa_dtype, from_pandas=True))
+            )
+
+    @classmethod
+    def _from_sequence_of_strings(
+        cls, strings, *, dtype: Dtype | None = None, copy=False
+    ):
+        """
+        Construct a new ExtensionArray from a sequence of strings.
+        """
+        return cls._from_sequence(strings, dtype=dtype, copy=copy)
+
+    def __getitem__(self, item: PositionalIndexer):
+        """Select a subset of self.
+
+        Parameters
+        ----------
+        item : int, slice, or ndarray
+            * int: The position in 'self' to get.
+            * slice: A slice object, where 'start', 'stop', and 'step' are
+              integers or None
+            * ndarray: A 1-d boolean NumPy ndarray the same length as 'self'
+
+        Returns
+        -------
+        item : scalar or ExtensionArray
+
+        Notes
+        -----
+        For scalar ``item``, return a scalar value suitable for the array's
+        type. This should be an instance of ``self.dtype.type``.
+        For slice ``key``, return an instance of ``ExtensionArray``, even
+        if the slice is length 0 or 1.
+        For a boolean mask, return an instance of ``ExtensionArray``, filtered
+        to the values where ``item`` is True.
+        """
+        item = check_array_indexer(self, item)
+
+        if isinstance(item, np.ndarray):
+            if not len(item):
+                # Removable once we migrate StringDtype[pyarrow] to ArrowDtype[string]
+                if self._dtype.name == "string" and self._dtype.storage == "pyarrow":
+                    pa_dtype = pa.string()
+                else:
+                    pa_dtype = self._dtype.pyarrow_dtype
+                return type(self)(pa.chunked_array([], type=pa_dtype))
+            elif is_integer_dtype(item.dtype):
+                return self.take(item)
+            elif is_bool_dtype(item.dtype):
+                return type(self)(self._data.filter(item))
+            else:
+                raise IndexError(
+                    "Only integers, slices and integer or "
+                    "boolean arrays are valid indices."
+                )
+        elif isinstance(item, tuple):
+            item = unpack_tuple_and_ellipses(item)
+
+        # error: Non-overlapping identity check (left operand type:
+        # "Union[Union[int, integer[Any]], Union[slice, List[int],
+        # ndarray[Any, Any]]]", right operand type: "ellipsis")
+        if item is Ellipsis:  # type: ignore[comparison-overlap]
+            # TODO: should be handled by pyarrow?
+            item = slice(None)
+
+        if is_scalar(item) and not is_integer(item):
+            # e.g. "foo" or 2.5
+            # exception message copied from numpy
+            raise IndexError(
+                r"only integers, slices (`:`), ellipsis (`...`), numpy.newaxis "
+                r"(`None`) and integer or boolean arrays are valid indices"
+            )
+        # We are not an array indexer, so maybe e.g. a slice or integer
+        # indexer. We dispatch to pyarrow.
+        value = self._data[item]
+        if isinstance(value, pa.ChunkedArray):
+            return type(self)(value)
+        else:
+            scalar = value.as_py()
+            if scalar is None:
+                return self._dtype.na_value
+            else:
+                return scalar
 
     def __arrow_array__(self, type=None):
-        """Convert myself to a pyarrow Array or ChunkedArray."""
+        """Convert myself to a pyarrow ChunkedArray."""
         return self._data
 
     def equals(self, other) -> bool:
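A rough sketch of the constructor and __getitem__ behavior added in the hunk above, assuming pyarrow >= 1.0.0 is installed; the values and assertions are illustrative, not tests from this commit:

import numpy as np
import pyarrow as pa

from pandas.core.arrays.arrow.array import ArrowExtensionArray

# The constructor now accepts a pyarrow Array and wraps it in a one-chunk ChunkedArray.
arr = ArrowExtensionArray(pa.array(["a", "b", None], type=pa.string()))

assert arr[0] == "a"                                  # scalar item: pyarrow scalar unwrapped via .as_py()
assert arr[2] is arr.dtype.na_value                   # null slot returns the NA sentinel (dtype property added in the next hunk)
assert isinstance(arr[1:], ArrowExtensionArray)       # slices dispatch to pyarrow and re-wrap
assert len(arr[np.array([True, False, True])]) == 2   # boolean masks go through ChunkedArray.filter
assert len(arr[np.array([], dtype=np.intp)]) == 0     # empty indexers keep the underlying pyarrow type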
@@ -67,6 +186,13 @@ def equals(self, other) -> bool:
         # TODO: is this documented somewhere?
         return self._data == other._data
 
+    @property
+    def dtype(self) -> ArrowDtype:
+        """
+        An instance of 'ExtensionDtype'.
+        """
+        return self._dtype
+
     @property
     def nbytes(self) -> int:
         """
@@ -377,7 +503,8 @@ def _indexing_key_to_indices(
 
     def _maybe_convert_setitem_value(self, value):
         """Maybe convert value to be pyarrow compatible."""
-        raise NotImplementedError()
+        # TODO: Make more robust like ArrowStringArray._maybe_convert_setitem_value
+        return value
 
     def _set_via_chunk_iteration(
         self, indices: npt.NDArray[np.intp], value: npt.NDArray[Any]
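Tying the pieces together, a quick check of the new dtype property and _from_sequence casting might look like this sketch (again assuming pyarrow is installed; it relies only on behavior shown in this diff):

import pyarrow as pa

from pandas.core.arrays.arrow.array import ArrowExtensionArray
from pandas.core.arrays.arrow.dtype import ArrowDtype

# Passing an ArrowDtype makes _from_sequence build the ChunkedArray with that pyarrow type.
arr = ArrowExtensionArray._from_sequence([1, 2, None], dtype=ArrowDtype(pa.int16()))

assert isinstance(arr.dtype, ArrowDtype)          # the new dtype property
assert arr.dtype.pyarrow_dtype == pa.int16()      # the dtype wraps the underlying pyarrow type
assert arr.__arrow_array__().type == pa.int16()   # and matches the stored ChunkedArray's type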
