Skip to content

TST: use single-class pattern in test_string.py #56509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,9 @@ def __invert__(self) -> Self:
# This is a bit wise op for integer types
if pa.types.is_integer(self._pa_array.type):
return type(self)(pc.bit_wise_not(self._pa_array))
elif pa.types.is_string(self._pa_array.type):
# Raise TypeError instead of pa.ArrowNotImplementedError
raise TypeError("__invert__ is not supported for string dtypes")
else:
return type(self)(pc.invert(self._pa_array))

Expand Down
6 changes: 5 additions & 1 deletion pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,11 @@ def test_EA_types(self, engine, data, dtype_backend, request):

def test_invert(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if not (pa.types.is_boolean(pa_dtype) or pa.types.is_integer(pa_dtype)):
if not (
pa.types.is_boolean(pa_dtype)
or pa.types.is_integer(pa_dtype)
or pa.types.is_string(pa_dtype)
):
request.applymarker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
Expand Down
108 changes: 61 additions & 47 deletions pandas/tests/extension/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
be added to the array-specific tests in `pandas/tests/arrays/`.

"""
from __future__ import annotations

import string
from typing import cast

import numpy as np
import pytest
Expand Down Expand Up @@ -90,7 +93,7 @@ def data_for_grouping(dtype, chunked):
return maybe_split_array(arr, chunked)


class TestDtype(base.BaseDtypeTests):
class TestStringArray(base.ExtensionTests):
def test_eq_with_str(self, dtype):
assert dtype == f"string[{dtype.storage}]"
super().test_eq_with_str(dtype)
Expand All @@ -100,43 +103,25 @@ def test_is_not_string_type(self, dtype):
# because StringDtype is a string type
assert is_string_dtype(dtype)


class TestInterface(base.BaseInterfaceTests):
def test_view(self, data, request, arrow_string_storage):
if data.dtype.storage in arrow_string_storage:
pytest.skip(reason="2D support not implemented for ArrowStringArray")
super().test_view(data)


class TestConstructors(base.BaseConstructorsTests):
def test_from_dtype(self, data):
# base test uses string representation of dtype
pass


class TestReshaping(base.BaseReshapingTests):
def test_transpose(self, data, request, arrow_string_storage):
if data.dtype.storage in arrow_string_storage:
pytest.skip(reason="2D support not implemented for ArrowStringArray")
super().test_transpose(data)


class TestGetitem(base.BaseGetitemTests):
pass


class TestSetitem(base.BaseSetitemTests):
def test_setitem_preserves_views(self, data, request, arrow_string_storage):
if data.dtype.storage in arrow_string_storage:
pytest.skip(reason="2D support not implemented for ArrowStringArray")
super().test_setitem_preserves_views(data)


class TestIndex(base.BaseIndexTests):
pass


class TestMissing(base.BaseMissingTests):
def test_dropna_array(self, data_missing):
result = data_missing.dropna()
expected = data_missing[[1]]
Expand All @@ -154,51 +139,80 @@ def test_fillna_no_op_returns_copy(self, data):
assert result is not data
tm.assert_extension_array_equal(result, data)

def _get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
if op_name in ["__divmod__", "__rdivmod__"]:
if isinstance(obj, pd.Series) and cast(
StringDtype, tm.get_dtype(obj)
).storage in [
"pyarrow",
"pyarrow_numpy",
]:
# TODO: re-raise as TypeError?
return NotImplementedError
elif isinstance(other, pd.Series) and cast(
StringDtype, tm.get_dtype(other)
).storage in [
"pyarrow",
"pyarrow_numpy",
]:
# TODO: re-raise as TypeError?
return NotImplementedError
return TypeError
elif op_name in ["__mod__", "__rmod__", "__pow__", "__rpow__"]:
if cast(StringDtype, tm.get_dtype(obj)).storage in [
"pyarrow",
"pyarrow_numpy",
]:
return NotImplementedError
return TypeError
elif op_name in ["__mul__", "__rmul__"]:
# Can only multiply strings by integers
return TypeError
elif op_name in [
"__truediv__",
"__rtruediv__",
"__floordiv__",
"__rfloordiv__",
"__sub__",
"__rsub__",
]:
if cast(StringDtype, tm.get_dtype(obj)).storage in [
"pyarrow",
"pyarrow_numpy",
]:
import pyarrow as pa

# TODO: better to re-raise as TypeError?
return pa.ArrowNotImplementedError
return TypeError

return None

class TestReduce(base.BaseReduceTests):
def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
return (
op_name in ["min", "max"]
or ser.dtype.storage == "pyarrow_numpy" # type: ignore[union-attr]
and op_name in ("any", "all")
)


class TestMethods(base.BaseMethodsTests):
pass


class TestCasting(base.BaseCastingTests):
pass


class TestComparisonOps(base.BaseComparisonOpsTests):
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
dtype = tm.get_dtype(obj)
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no
# attribute "storage"
if dtype.storage == "pyarrow": # type: ignore[union-attr]
cast_to = "boolean[pyarrow]"
elif dtype.storage == "pyarrow_numpy": # type: ignore[union-attr]
dtype = cast(StringDtype, tm.get_dtype(obj))
if op_name in ["__add__", "__radd__"]:
cast_to = dtype
elif dtype.storage == "pyarrow":
cast_to = "boolean[pyarrow]" # type: ignore[assignment]
elif dtype.storage == "pyarrow_numpy":
cast_to = np.bool_ # type: ignore[assignment]
else:
cast_to = "boolean"
cast_to = "boolean" # type: ignore[assignment]
return pointwise_result.astype(cast_to)

def test_compare_scalar(self, data, comparison_op):
ser = pd.Series(data)
self._compare_other(ser, data, comparison_op, "abc")


class TestParsing(base.BaseParsingTests):
pass


class TestPrinting(base.BasePrintingTests):
pass


class TestGroupBy(base.BaseGroupbyTests):
@pytest.mark.filterwarnings("ignore:Falling back:pandas.errors.PerformanceWarning")
def test_groupby_extension_apply(self, data_for_grouping, groupby_apply_op):
super().test_groupby_extension_apply(data_for_grouping, groupby_apply_op)
Expand Down