Skip to content

Commit 8f8e3a8

Browse files
authored
BUG/TST: ArrowStringArray skips/xfails (#44795)
1 parent 18de3ac commit 8f8e3a8

File tree

6 files changed

+58
-36
lines changed

6 files changed

+58
-36
lines changed

pandas/core/arrays/string_arrow.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def _cmp_method(self, other, op):
367367
pc_func = ARROW_CMP_FUNCS[op.__name__]
368368
if isinstance(other, ArrowStringArray):
369369
result = pc_func(self._data, other._data)
370-
elif isinstance(other, np.ndarray):
370+
elif isinstance(other, (np.ndarray, list)):
371371
result = pc_func(self._data, other)
372372
elif is_scalar(other):
373373
try:

pandas/tests/arrays/string_/test_string.py

+20-22
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,18 @@ def test_comparison_methods_scalar_pd_na(comparison_op, dtype):
217217
tm.assert_extension_array_equal(result, expected)
218218

219219

220-
def test_comparison_methods_scalar_not_string(comparison_op, dtype, request):
220+
def test_comparison_methods_scalar_not_string(comparison_op, dtype):
221221
op_name = f"__{comparison_op.__name__}__"
222-
if op_name not in ["__eq__", "__ne__"]:
223-
reason = "comparison op not supported between instances of 'str' and 'int'"
224-
mark = pytest.mark.xfail(raises=TypeError, reason=reason)
225-
request.node.add_marker(mark)
226222

227223
a = pd.array(["a", None, "c"], dtype=dtype)
228224
other = 42
225+
226+
if op_name not in ["__eq__", "__ne__"]:
227+
with pytest.raises(TypeError, match="not supported between"):
228+
getattr(a, op_name)(other)
229+
230+
return
231+
229232
result = getattr(a, op_name)(other)
230233
expected_data = {"__eq__": [False, None, False], "__ne__": [True, None, True]}[
231234
op_name
@@ -234,12 +237,7 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype, request):
234237
tm.assert_extension_array_equal(result, expected)
235238

236239

237-
def test_comparison_methods_array(comparison_op, dtype, request):
238-
if dtype.storage == "pyarrow":
239-
mark = pytest.mark.xfail(
240-
raises=AssertionError, reason="left is not an ExtensionArray"
241-
)
242-
request.node.add_marker(mark)
240+
def test_comparison_methods_array(comparison_op, dtype):
243241

244242
op_name = f"__{comparison_op.__name__}__"
245243

@@ -340,6 +338,17 @@ def test_reduce(skipna, dtype):
340338
assert result == "abc"
341339

342340

341+
@pytest.mark.parametrize("skipna", [True, False])
342+
@pytest.mark.xfail(reason="Not implemented StringArray.sum")
343+
def test_reduce_missing(skipna, dtype):
344+
arr = pd.Series([None, "a", None, "b", "c", None], dtype=dtype)
345+
result = arr.sum(skipna=skipna)
346+
if skipna:
347+
assert result == "abc"
348+
else:
349+
assert pd.isna(result)
350+
351+
343352
@pytest.mark.parametrize("method", ["min", "max"])
344353
@pytest.mark.parametrize("skipna", [True, False])
345354
def test_min_max(method, skipna, dtype, request):
@@ -374,17 +383,6 @@ def test_min_max_numpy(method, box, dtype, request):
374383
assert result == expected
375384

376385

377-
@pytest.mark.parametrize("skipna", [True, False])
378-
@pytest.mark.xfail(reason="Not implemented StringArray.sum")
379-
def test_reduce_missing(skipna, dtype):
380-
arr = pd.Series([None, "a", None, "b", "c", None], dtype=dtype)
381-
result = arr.sum(skipna=skipna)
382-
if skipna:
383-
assert result == "abc"
384-
else:
385-
assert pd.isna(result)
386-
387-
388386
def test_fillna_args(dtype, request):
389387
# GH 37987
390388

pandas/tests/extension/arrow/arrays.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
from pandas.api.types import is_scalar
2828
from pandas.core.arraylike import OpsMixin
29+
from pandas.core.construction import extract_array
2930

3031

3132
@register_extension_dtype
@@ -77,6 +78,16 @@ class ArrowExtensionArray(OpsMixin, ExtensionArray):
7778

7879
@classmethod
7980
def from_scalars(cls, values):
81+
if isinstance(values, cls):
82+
# in particular for empty cases the pa.array(np.asarray(...))
83+
# does not round-trip
84+
return cls(values._data)
85+
86+
elif not len(values):
87+
if isinstance(values, list):
88+
dtype = bool if cls is ArrowBoolArray else str
89+
values = np.array([], dtype=dtype)
90+
8091
arr = pa.chunked_array([pa.array(np.asarray(values))])
8192
return cls(arr)
8293

@@ -92,6 +103,14 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
92103
def __repr__(self):
93104
return f"{type(self).__name__}({repr(self._data)})"
94105

106+
def __contains__(self, obj) -> bool:
107+
if obj is None or obj is self.dtype.na_value:
108+
# None -> EA.__contains__ only checks for self._dtype.na_value, not
109+
# any compatible NA value.
110+
# self.dtype.na_value -> <pa.NullScalar:None> isn't recognized by pd.isna
111+
return bool(self.isna().any())
112+
return bool(super().__contains__(obj))
113+
95114
def __getitem__(self, item):
96115
if is_scalar(item):
97116
return self._data.to_pandas()[item]
@@ -125,7 +144,8 @@ def _logical_method(self, other, op):
125144

126145
def __eq__(self, other):
127146
if not isinstance(other, type(self)):
128-
return False
147+
# TODO: use some pyarrow function here?
148+
return np.asarray(self).__eq__(other)
129149

130150
return self._logical_method(other, operator.eq)
131151

@@ -144,6 +164,7 @@ def isna(self):
144164

145165
def take(self, indices, allow_fill=False, fill_value=None):
146166
data = self._data.to_pandas()
167+
data = extract_array(data, extract_numpy=True)
147168

148169
if allow_fill and fill_value is None:
149170
fill_value = self.dtype.na_value

pandas/tests/extension/arrow/test_bool.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def test_view(self, data):
5454
data.view()
5555

5656
@pytest.mark.xfail(
57-
raises=AttributeError,
58-
reason="__eq__ incorrectly returns bool instead of ndarray[bool]",
57+
raises=AssertionError,
58+
reason="Doesn't recognize data._na_value as NA",
5959
)
6060
def test_contains(self, data, data_missing):
6161
super().test_contains(data, data_missing)
@@ -77,7 +77,7 @@ def test_series_constructor_scalar_na_with_index(self, dtype, na_value):
7777
# pyarrow.lib.ArrowInvalid: only handle 1-dimensional arrays
7878
super().test_series_constructor_scalar_na_with_index(dtype, na_value)
7979

80-
@pytest.mark.xfail(reason="raises AssertionError")
80+
@pytest.mark.xfail(reason="ufunc 'invert' not supported for the input types")
8181
def test_construct_empty_dataframe(self, dtype):
8282
super().test_construct_empty_dataframe(dtype)
8383

pandas/tests/extension/test_string.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,9 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):
146146
if op_name in ["min", "max"]:
147147
return None
148148

149-
s = pd.Series(data)
149+
ser = pd.Series(data)
150150
with pytest.raises(TypeError):
151-
getattr(s, op_name)(skipna=skipna)
151+
getattr(ser, op_name)(skipna=skipna)
152152

153153

154154
class TestMethods(base.BaseMethodsTests):
@@ -166,15 +166,15 @@ class TestCasting(base.BaseCastingTests):
166166

167167

168168
class TestComparisonOps(base.BaseComparisonOpsTests):
169-
def _compare_other(self, s, data, op, other):
169+
def _compare_other(self, ser, data, op, other):
170170
op_name = f"__{op.__name__}__"
171-
result = getattr(s, op_name)(other)
172-
expected = getattr(s.astype(object), op_name)(other).astype("boolean")
171+
result = getattr(ser, op_name)(other)
172+
expected = getattr(ser.astype(object), op_name)(other).astype("boolean")
173173
self.assert_series_equal(result, expected)
174174

175175
def test_compare_scalar(self, data, comparison_op):
176-
s = pd.Series(data)
177-
self._compare_other(s, data, comparison_op, "abc")
176+
ser = pd.Series(data)
177+
self._compare_other(ser, data, comparison_op, "abc")
178178

179179

180180
class TestParsing(base.BaseParsingTests):

pandas/tests/strings/test_string_array.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@
1212

1313
def test_string_array(nullable_string_dtype, any_string_method):
1414
method_name, args, kwargs = any_string_method
15-
if method_name == "decode":
16-
pytest.skip("decode requires bytes.")
1715

1816
data = ["a", "bb", np.nan, "ccc"]
1917
a = Series(data, dtype=object)
2018
b = Series(data, dtype=nullable_string_dtype)
2119

20+
if method_name == "decode":
21+
with pytest.raises(TypeError, match="a bytes-like object is required"):
22+
getattr(b.str, method_name)(*args, **kwargs)
23+
return
24+
2225
expected = getattr(a.str, method_name)(*args, **kwargs)
2326
result = getattr(b.str, method_name)(*args, **kwargs)
2427

0 commit comments

Comments
 (0)