Skip to content

ENH/TST: Add BaseInterfaceTests tests for ArrowExtensionArray #47377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 22, 2022
40 changes: 39 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from pandas.core.dtypes.missing import isna

from pandas.core.arraylike import OpsMixin
from pandas.core.arrays.base import ExtensionArray
from pandas.core.indexers import (
check_array_indexer,
Expand All @@ -45,13 +46,22 @@
from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
from pandas.core.arrays.arrow.dtype import ArrowDtype

ARROW_CMP_FUNCS = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't there a min version on some of these?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like they all have been around since pyarrow version 2 which I think is the version we require to use these functions

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Plus we catch the except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid) when calling these with fallback behavior.

"eq": pc.equal,
"ne": pc.not_equal,
"lt": pc.less,
"gt": pc.greater,
"le": pc.less_equal,
"ge": pc.greater_equal,
}

if TYPE_CHECKING:
from pandas import Series

ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")


class ArrowExtensionArray(ExtensionArray):
class ArrowExtensionArray(OpsMixin, ExtensionArray):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I be using ExtensionOpsMixin or OpsMixin @jbrockmendel?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i recommend OpsMixin

"""
Base class for ExtensionArray backed by Arrow ChunkedArray.
"""
Expand Down Expand Up @@ -179,6 +189,34 @@ def __arrow_array__(self, type=None):
"""Convert myself to a pyarrow ChunkedArray."""
return self._data

def _cmp_method(self, other, op):
from pandas.arrays import BooleanArray

pc_func = ARROW_CMP_FUNCS[op.__name__]
if isinstance(other, ArrowExtensionArray):
result = pc_func(self._data, other._data)
elif isinstance(other, (np.ndarray, list)):
result = pc_func(self._data, other)
elif is_scalar(other):
try:
result = pc_func(self._data, pa.scalar(other))
except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid):
mask = isna(self) | isna(other)
valid = ~mask
result = np.zeros(len(self), dtype="bool")
result[valid] = op(np.array(self)[valid], other)
return BooleanArray(result, mask)
else:
return NotImplementedError(
f"{op.__name__} not implemented for {type(other)}"
)

if pa_version_under2p0:
result = result.to_pandas().values
else:
result = result.to_numpy()
return BooleanArray._from_sequence(result)

def equals(self, other) -> bool:
if not isinstance(other, ArrowExtensionArray):
return False
Expand Down
40 changes: 1 addition & 39 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
)
from pandas.core.dtypes.missing import isna

from pandas.core.arraylike import OpsMixin
from pandas.core.arrays.arrow import ArrowExtensionArray
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.integer import Int64Dtype
Expand All @@ -51,15 +50,6 @@

from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning

ARROW_CMP_FUNCS = {
"eq": pc.equal,
"ne": pc.not_equal,
"lt": pc.less,
"gt": pc.greater,
"le": pc.less_equal,
"ge": pc.greater_equal,
}

ArrowStringScalarOrNAT = Union[str, libmissing.NAType]


Expand All @@ -74,9 +64,7 @@ def _chk_pyarrow_available() -> None:
# fallback for the ones that pyarrow doesn't yet support


class ArrowStringArray(
OpsMixin, ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin
):
class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin):
"""
Extension array for string data in a ``pyarrow.ChunkedArray``.

Expand Down Expand Up @@ -190,32 +178,6 @@ def to_numpy(
result[mask] = na_value
return result

def _cmp_method(self, other, op):
from pandas.arrays import BooleanArray

pc_func = ARROW_CMP_FUNCS[op.__name__]
if isinstance(other, ArrowStringArray):
result = pc_func(self._data, other._data)
elif isinstance(other, (np.ndarray, list)):
result = pc_func(self._data, other)
elif is_scalar(other):
try:
result = pc_func(self._data, pa.scalar(other))
except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid):
mask = isna(self) | isna(other)
valid = ~mask
result = np.zeros(len(self), dtype="bool")
result[valid] = op(np.array(self)[valid], other)
return BooleanArray(result, mask)
else:
return NotImplemented

if pa_version_under2p0:
result = result.to_pandas().values
else:
result = result.to_numpy()
return BooleanArray._from_sequence(result)

def insert(self, loc: int, item):
if not isinstance(item, str) and item is not libmissing.NA:
raise TypeError("Scalar must be NA or str")
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/arrow/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
take,
)
from pandas.api.types import is_scalar
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays.arrow import ArrowExtensionArray as _ArrowExtensionArray
from pandas.core.construction import extract_array

Expand Down Expand Up @@ -72,7 +71,7 @@ def construct_array_type(cls) -> type_t[ArrowStringArray]:
return ArrowStringArray


class ArrowExtensionArray(OpsMixin, _ArrowExtensionArray):
class ArrowExtensionArray(_ArrowExtensionArray):
_data: pa.ChunkedArray

@classmethod
Expand Down
22 changes: 21 additions & 1 deletion pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from pandas.core.arrays.arrow.dtype import ArrowDtype # isort:skip


@pytest.fixture(params=tm.ALL_PYARROW_DTYPES)
@pytest.fixture(params=tm.ALL_PYARROW_DTYPES, ids=str)
def dtype(request):
return ArrowDtype(pyarrow_dtype=request.param)

Expand Down Expand Up @@ -201,6 +201,26 @@ class TestBaseIndex(base.BaseIndexTests):
pass


class TestBaseInterface(base.BaseInterfaceTests):
def test_contains(self, data, data_missing, request):
tz = getattr(data.dtype.pyarrow_dtype, "tz", None)
unit = getattr(data.dtype.pyarrow_dtype, "unit", None)
if pa_version_under2p0 and tz not in (None, "UTC") and unit == "us":
request.node.add_marker(
pytest.mark.xfail(
reason=(
f"Not supported by pyarrow < 2.0 "
f"with timestamp type {tz} and {unit}"
)
)
)
super().test_contains(data, data_missing)

@pytest.mark.xfail(reason="pyarrow.ChunkedArray does not support views.")
def test_view(self, data):
super().test_view(data)


def test_arrowdtype_construct_from_string_type_with_parameters():
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
ArrowDtype.construct_from_string("timestamp[s][pyarrow]")