Skip to content

Commit 78d8b2e

Browse files
authored
TST: use single-class pattern in test_string.py (#56509)
* TST: use single-class pattern in test_string.py
* mypy fixup
* update pyarrow test
* __future__ import
* mypy fixup
1 parent 55639a3 commit 78d8b2e

File tree

3 files changed

+69
-48
lines changed

3 files changed

+69
-48
lines changed

pandas/core/arrays/arrow/array.py

+3
Original file line number | Diff line number | Diff line change
@@ -632,6 +632,9 @@ def __invert__(self) -> Self:
632632
# This is a bit wise op for integer types
633633
if pa.types.is_integer(self._pa_array.type):
634634
return type(self)(pc.bit_wise_not(self._pa_array))
635+
elif pa.types.is_string(self._pa_array.type):
636+
# Raise TypeError instead of pa.ArrowNotImplementedError
637+
raise TypeError("__invert__ is not supported for string dtypes")
635638
else:
636639
return type(self)(pc.invert(self._pa_array))
637640

pandas/tests/extension/test_arrow.py

+5-1
Original file line number | Diff line number | Diff line change
@@ -753,7 +753,11 @@ def test_EA_types(self, engine, data, dtype_backend, request):
753753

754754
def test_invert(self, data, request):
755755
pa_dtype = data.dtype.pyarrow_dtype
756-
if not (pa.types.is_boolean(pa_dtype) or pa.types.is_integer(pa_dtype)):
756+
if not (
757+
pa.types.is_boolean(pa_dtype)
758+
or pa.types.is_integer(pa_dtype)
759+
or pa.types.is_string(pa_dtype)
760+
):
757761
request.applymarker(
758762
pytest.mark.xfail(
759763
raises=pa.ArrowNotImplementedError,

pandas/tests/extension/test_string.py

+61-47
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,10 @@
1313
be added to the array-specific tests in `pandas/tests/arrays/`.
1414
1515
"""
16+
from __future__ import annotations
17+
1618
import string
19+
from typing import cast
1720

1821
import numpy as np
1922
import pytest
@@ -90,7 +93,7 @@ def data_for_grouping(dtype, chunked):
9093
return maybe_split_array(arr, chunked)
9194

9295

93-
class TestDtype(base.BaseDtypeTests):
96+
class TestStringArray(base.ExtensionTests):
9497
def test_eq_with_str(self, dtype):
9598
assert dtype == f"string[{dtype.storage}]"
9699
super().test_eq_with_str(dtype)
@@ -100,43 +103,25 @@ def test_is_not_string_type(self, dtype):
100103
# because StringDtype is a string type
101104
assert is_string_dtype(dtype)
102105

103-
104-
class TestInterface(base.BaseInterfaceTests):
105106
def test_view(self, data, request, arrow_string_storage):
106107
if data.dtype.storage in arrow_string_storage:
107108
pytest.skip(reason="2D support not implemented for ArrowStringArray")
108109
super().test_view(data)
109110

110-
111-
class TestConstructors(base.BaseConstructorsTests):
112111
def test_from_dtype(self, data):
113112
# base test uses string representation of dtype
114113
pass
115114

116-
117-
class TestReshaping(base.BaseReshapingTests):
118115
def test_transpose(self, data, request, arrow_string_storage):
119116
if data.dtype.storage in arrow_string_storage:
120117
pytest.skip(reason="2D support not implemented for ArrowStringArray")
121118
super().test_transpose(data)
122119

123-
124-
class TestGetitem(base.BaseGetitemTests):
125-
pass
126-
127-
128-
class TestSetitem(base.BaseSetitemTests):
129120
def test_setitem_preserves_views(self, data, request, arrow_string_storage):
130121
if data.dtype.storage in arrow_string_storage:
131122
pytest.skip(reason="2D support not implemented for ArrowStringArray")
132123
super().test_setitem_preserves_views(data)
133124

134-
135-
class TestIndex(base.BaseIndexTests):
136-
pass
137-
138-
139-
class TestMissing(base.BaseMissingTests):
140125
def test_dropna_array(self, data_missing):
141126
result = data_missing.dropna()
142127
expected = data_missing[[1]]
@@ -154,51 +139,80 @@ def test_fillna_no_op_returns_copy(self, data):
154139
assert result is not data
155140
tm.assert_extension_array_equal(result, data)
156141

142+
def _get_expected_exception(
143+
self, op_name: str, obj, other
144+
) -> type[Exception] | None:
145+
if op_name in ["__divmod__", "__rdivmod__"]:
146+
if isinstance(obj, pd.Series) and cast(
147+
StringDtype, tm.get_dtype(obj)
148+
).storage in [
149+
"pyarrow",
150+
"pyarrow_numpy",
151+
]:
152+
# TODO: re-raise as TypeError?
153+
return NotImplementedError
154+
elif isinstance(other, pd.Series) and cast(
155+
StringDtype, tm.get_dtype(other)
156+
).storage in [
157+
"pyarrow",
158+
"pyarrow_numpy",
159+
]:
160+
# TODO: re-raise as TypeError?
161+
return NotImplementedError
162+
return TypeError
163+
elif op_name in ["__mod__", "__rmod__", "__pow__", "__rpow__"]:
164+
if cast(StringDtype, tm.get_dtype(obj)).storage in [
165+
"pyarrow",
166+
"pyarrow_numpy",
167+
]:
168+
return NotImplementedError
169+
return TypeError
170+
elif op_name in ["__mul__", "__rmul__"]:
171+
# Can only multiply strings by integers
172+
return TypeError
173+
elif op_name in [
174+
"__truediv__",
175+
"__rtruediv__",
176+
"__floordiv__",
177+
"__rfloordiv__",
178+
"__sub__",
179+
"__rsub__",
180+
]:
181+
if cast(StringDtype, tm.get_dtype(obj)).storage in [
182+
"pyarrow",
183+
"pyarrow_numpy",
184+
]:
185+
import pyarrow as pa
186+
187+
# TODO: better to re-raise as TypeError?
188+
return pa.ArrowNotImplementedError
189+
return TypeError
190+
191+
return None
157192

158-
class TestReduce(base.BaseReduceTests):
159193
def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
160194
return (
161195
op_name in ["min", "max"]
162196
or ser.dtype.storage == "pyarrow_numpy" # type: ignore[union-attr]
163197
and op_name in ("any", "all")
164198
)
165199

166-
167-
class TestMethods(base.BaseMethodsTests):
168-
pass
169-
170-
171-
class TestCasting(base.BaseCastingTests):
172-
pass
173-
174-
175-
class TestComparisonOps(base.BaseComparisonOpsTests):
176200
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
177-
dtype = tm.get_dtype(obj)
178-
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no
179-
# attribute "storage"
180-
if dtype.storage == "pyarrow": # type: ignore[union-attr]
181-
cast_to = "boolean[pyarrow]"
182-
elif dtype.storage == "pyarrow_numpy": # type: ignore[union-attr]
201+
dtype = cast(StringDtype, tm.get_dtype(obj))
202+
if op_name in ["__add__", "__radd__"]:
203+
cast_to = dtype
204+
elif dtype.storage == "pyarrow":
205+
cast_to = "boolean[pyarrow]" # type: ignore[assignment]
206+
elif dtype.storage == "pyarrow_numpy":
183207
cast_to = np.bool_ # type: ignore[assignment]
184208
else:
185-
cast_to = "boolean"
209+
cast_to = "boolean" # type: ignore[assignment]
186210
return pointwise_result.astype(cast_to)
187211

188212
def test_compare_scalar(self, data, comparison_op):
189213
ser = pd.Series(data)
190214
self._compare_other(ser, data, comparison_op, "abc")
191215

192-
193-
class TestParsing(base.BaseParsingTests):
194-
pass
195-
196-
197-
class TestPrinting(base.BasePrintingTests):
198-
pass
199-
200-
201-
class TestGroupBy(base.BaseGroupbyTests):
202216
@pytest.mark.filterwarnings("ignore:Falling back:pandas.errors.PerformanceWarning")
203217
def test_groupby_extension_apply(self, data_for_grouping, groupby_apply_op):
204218
super().test_groupby_extension_apply(data_for_grouping, groupby_apply_op)

0 commit comments

Comments
 (0)