Skip to content

REF: implement ArrowExtensionArray base class #46102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 26, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions pandas/core/arrays/_arrow_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from __future__ import annotations

import json
from typing import TypeVar

import numpy as np
import pyarrow

from pandas._typing import npt

from pandas.core.arrays.base import ExtensionArray
from pandas.core.arrays.interval import VALID_CLOSED


Expand Down Expand Up @@ -139,3 +145,78 @@ def to_pandas_dtype(self):
# register the type with a dummy instance
_interval_type = ArrowIntervalType(pyarrow.int64(), "left")
pyarrow.register_extension_type(_interval_type)


ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")


class ArrowExtensionArray(ExtensionArray):
"""
Base class for ExtensionArray backed by Arrow array.
"""

_data: pyarrow.ChunkedArray

def __init__(self, values: pyarrow.ChunkedArray):
raise NotImplementedError

def __arrow_array__(self, type=None):
"""Convert myself to a pyarrow Array or ChunkedArray."""
return self._data

@property
def nbytes(self) -> int:
"""
The number of bytes needed to store this object in memory.
"""
return self._data.nbytes

def __len__(self) -> int:
"""
Length of this array.

Returns
-------
length : int
"""
return len(self._data)

def isna(self) -> npt.NDArray[np.bool_]:
"""
Boolean NumPy array indicating if each value is missing.

This should return a 1-D array the same length as 'self'.
"""
# TODO: Implement .to_numpy for ChunkedArray
return self._data.is_null().to_pandas().values

def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
"""
Return a shallow copy of the array.

Underlying ChunkedArray is immutable, so a deep copy is unnecessary.

Returns
-------
type(self)
"""
return type(self)(self._data)

@classmethod
def _concat_same_type(
cls: type[ArrowExtensionArrayT], to_concat
) -> ArrowExtensionArrayT:
"""
Concatenate multiple ArrowExtensionArrays.

Parameters
----------
to_concat : sequence of ArrowExtensionArrays

Returns
-------
ArrowExtensionArray
"""
chunks = [array for ea in to_concat for array in ea._data.iterchunks()]
arr = pyarrow.chunked_array(chunks)
return cls(arr)
6 changes: 5 additions & 1 deletion pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@
from pandas.core import ops
from pandas.core.array_algos import masked_reductions
from pandas.core.arrays import (
ExtensionArray,
FloatingArray,
IntegerArray,
PandasArray,
)
from pandas.core.arrays.base import ExtensionArray
from pandas.core.arrays.floating import FloatingDtype
from pandas.core.arrays.integer import IntegerDtype
from pandas.core.construction import extract_array
Expand Down Expand Up @@ -224,6 +224,10 @@ def __from_arrow__(


class BaseStringArray(ExtensionArray):
"""
Mixin class for StringArray, ArrowStringArray.
"""

pass


Expand Down
66 changes: 4 additions & 62 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from pandas.core.dtypes.missing import isna

from pandas.core.arraylike import OpsMixin
from pandas.core.arrays._arrow_utils import ArrowExtensionArray
from pandas.core.arrays.base import ExtensionArray
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.integer import Int64Dtype
Expand Down Expand Up @@ -94,7 +95,9 @@ def _chk_pyarrow_available() -> None:
# fallback for the ones that pyarrow doesn't yet support


class ArrowStringArray(OpsMixin, BaseStringArray, ObjectStringArrayMixin):
class ArrowStringArray(
OpsMixin, ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin
):
"""
Extension array for string data in a ``pyarrow.ChunkedArray``.

Expand Down Expand Up @@ -191,10 +194,6 @@ def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
"""Correctly construct numpy arrays when passed to `np.asarray()`."""
return self.to_numpy(dtype=dtype)

def __arrow_array__(self, type=None):
"""Convert myself to a pyarrow Array or ChunkedArray."""
return self._data

def to_numpy(
self,
dtype: npt.DTypeLike | None = None,
Expand All @@ -216,16 +215,6 @@ def to_numpy(
result[mask] = na_value
return result

def __len__(self) -> int:
"""
Length of this array.

Returns
-------
length : int
"""
return len(self._data)

@doc(ExtensionArray.factorize)
def factorize(self, na_sentinel: int = -1) -> tuple[np.ndarray, ExtensionArray]:
encoded = self._data.dictionary_encode()
Expand All @@ -243,25 +232,6 @@ def factorize(self, na_sentinel: int = -1) -> tuple[np.ndarray, ExtensionArray]:

return indices.values, uniques

@classmethod
def _concat_same_type(cls, to_concat) -> ArrowStringArray:
"""
Concatenate multiple ArrowStringArray.

Parameters
----------
to_concat : sequence of ArrowStringArray

Returns
-------
ArrowStringArray
"""
return cls(
pa.chunked_array(
[array for ea in to_concat for array in ea._data.iterchunks()]
)
)

@overload
def __getitem__(self, item: ScalarIndexer) -> ArrowStringScalarOrNAT:
...
Expand Down Expand Up @@ -342,34 +312,6 @@ def _as_pandas_scalar(self, arrow_scalar: pa.Scalar):
else:
return scalar

@property
def nbytes(self) -> int:
"""
The number of bytes needed to store this object in memory.
"""
return self._data.nbytes

def isna(self) -> np.ndarray:
"""
Boolean NumPy array indicating if each value is missing.

This should return a 1-D array the same length as 'self'.
"""
# TODO: Implement .to_numpy for ChunkedArray
return self._data.is_null().to_pandas().values

def copy(self) -> ArrowStringArray:
"""
Return a shallow copy of the array.

Underlying ChunkedArray is immutable, so a deep copy is unnecessary.

Returns
-------
ArrowStringArray
"""
return type(self)(self._data)

def _cmp_method(self, other, op):
from pandas.arrays import BooleanArray

Expand Down
24 changes: 2 additions & 22 deletions pandas/tests/extension/arrow/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"""
from __future__ import annotations

import copy
import itertools
import operator

Expand All @@ -19,13 +18,13 @@

import pandas as pd
from pandas.api.extensions import (
ExtensionArray,
ExtensionDtype,
register_extension_dtype,
take,
)
from pandas.api.types import is_scalar
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays._arrow_utils import ArrowExtensionArray as _ArrowExtensionArray
from pandas.core.construction import extract_array


Expand Down Expand Up @@ -73,7 +72,7 @@ def construct_array_type(cls) -> type_t[ArrowStringArray]:
return ArrowStringArray


class ArrowExtensionArray(OpsMixin, ExtensionArray):
class ArrowExtensionArray(OpsMixin, _ArrowExtensionArray):
_data: pa.ChunkedArray

@classmethod
Expand Down Expand Up @@ -111,9 +110,6 @@ def __getitem__(self, item):
vals = self._data.to_pandas()[item]
return type(self)._from_sequence(vals)

def __len__(self):
return len(self._data)

def astype(self, dtype, copy=True):
# needed to fix this astype for the Series constructor.
if isinstance(dtype, type(self.dtype)) and dtype == self.dtype:
Expand Down Expand Up @@ -142,19 +138,6 @@ def __eq__(self, other):

return self._logical_method(other, operator.eq)

@property
def nbytes(self) -> int:
return sum(
x.size
for chunk in self._data.chunks
for x in chunk.buffers()
if x is not None
)

def isna(self):
nas = pd.isna(self._data.to_pandas())
return type(self)._from_sequence(nas)

def take(self, indices, allow_fill=False, fill_value=None):
data = self._data.to_pandas()
data = extract_array(data, extract_numpy=True)
Expand All @@ -165,9 +148,6 @@ def take(self, indices, allow_fill=False, fill_value=None):
result = take(data, indices, fill_value=fill_value, allow_fill=allow_fill)
return self._from_sequence(result, dtype=self.dtype)

def copy(self):
return type(self)(copy.copy(self._data))

@classmethod
def _concat_same_type(cls, to_concat):
chunks = list(itertools.chain.from_iterable(x._data.chunks for x in to_concat))
Expand Down