Skip to content

ENH/TST: Add BaseInterfaceTests tests for ArrowExtensionArray PT2 #47468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from pandas.core.dtypes.missing import isna

from pandas.core.arraylike import OpsMixin
from pandas.core.arrays.base import ExtensionArray
from pandas.core.indexers import (
check_array_indexer,
Expand All @@ -45,13 +46,22 @@
from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
from pandas.core.arrays.arrow.dtype import ArrowDtype

ARROW_CMP_FUNCS = {
"eq": pc.equal,
"ne": pc.not_equal,
"lt": pc.less,
"gt": pc.greater,
"le": pc.less_equal,
"ge": pc.greater_equal,
}

if TYPE_CHECKING:
from pandas import Series

ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")


class ArrowExtensionArray(ExtensionArray):
class ArrowExtensionArray(OpsMixin, ExtensionArray):
"""
Base class for ExtensionArray backed by Arrow ChunkedArray.
"""
Expand Down Expand Up @@ -179,6 +189,34 @@ def __arrow_array__(self, type=None):
"""Convert myself to a pyarrow ChunkedArray."""
return self._data

def _cmp_method(self, other, op):
from pandas.arrays import BooleanArray

pc_func = ARROW_CMP_FUNCS[op.__name__]
if isinstance(other, ArrowExtensionArray):
result = pc_func(self._data, other._data)
elif isinstance(other, (np.ndarray, list)):
result = pc_func(self._data, other)
elif is_scalar(other):
try:
result = pc_func(self._data, pa.scalar(other))
except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid):
mask = isna(self) | isna(other)
valid = ~mask
result = np.zeros(len(self), dtype="bool")
result[valid] = op(np.array(self)[valid], other)
return BooleanArray(result, mask)
else:
return NotImplementedError(
f"{op.__name__} not implemented for {type(other)}"
)

if pa_version_under2p0:
result = result.to_pandas().values
else:
result = result.to_numpy()
return BooleanArray._from_sequence(result)

def equals(self, other) -> bool:
if not isinstance(other, ArrowExtensionArray):
return False
Expand Down Expand Up @@ -589,7 +627,7 @@ def _replace_with_indices(
# fast path for a contiguous set of indices
arrays = [
chunk[:start],
pa.array(value, type=chunk.type),
pa.array(value, type=chunk.type, from_pandas=True),
chunk[stop + 1 :],
]
arrays = [arr for arr in arrays if len(arr)]
Expand Down
40 changes: 1 addition & 39 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
)
from pandas.core.dtypes.missing import isna

from pandas.core.arraylike import OpsMixin
from pandas.core.arrays.arrow import ArrowExtensionArray
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.integer import Int64Dtype
Expand All @@ -51,15 +50,6 @@

from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning

ARROW_CMP_FUNCS = {
"eq": pc.equal,
"ne": pc.not_equal,
"lt": pc.less,
"gt": pc.greater,
"le": pc.less_equal,
"ge": pc.greater_equal,
}

ArrowStringScalarOrNAT = Union[str, libmissing.NAType]


Expand All @@ -74,9 +64,7 @@ def _chk_pyarrow_available() -> None:
# fallback for the ones that pyarrow doesn't yet support


class ArrowStringArray(
OpsMixin, ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin
):
class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin):
"""
Extension array for string data in a ``pyarrow.ChunkedArray``.

Expand Down Expand Up @@ -190,32 +178,6 @@ def to_numpy(
result[mask] = na_value
return result

def _cmp_method(self, other, op):
from pandas.arrays import BooleanArray

pc_func = ARROW_CMP_FUNCS[op.__name__]
if isinstance(other, ArrowStringArray):
result = pc_func(self._data, other._data)
elif isinstance(other, (np.ndarray, list)):
result = pc_func(self._data, other)
elif is_scalar(other):
try:
result = pc_func(self._data, pa.scalar(other))
except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid):
mask = isna(self) | isna(other)
valid = ~mask
result = np.zeros(len(self), dtype="bool")
result[valid] = op(np.array(self)[valid], other)
return BooleanArray(result, mask)
else:
return NotImplemented

if pa_version_under2p0:
result = result.to_pandas().values
else:
result = result.to_numpy()
return BooleanArray._from_sequence(result)

def insert(self, loc: int, item):
if not isinstance(item, str) and item is not libmissing.NA:
raise TypeError("Scalar must be NA or str")
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/arrow/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
take,
)
from pandas.api.types import is_scalar
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays.arrow import ArrowExtensionArray as _ArrowExtensionArray
from pandas.core.construction import extract_array

Expand Down Expand Up @@ -72,7 +71,7 @@ def construct_array_type(cls) -> type_t[ArrowStringArray]:
return ArrowStringArray


class ArrowExtensionArray(OpsMixin, _ArrowExtensionArray):
class ArrowExtensionArray(_ArrowExtensionArray):
_data: pa.ChunkedArray

@classmethod
Expand Down
Loading