Skip to content

Commit c548c67

Browse files
committed
StringArray comparisons return BooleanArray
xref pandas-dev#29556
1 parent daa3158 commit c548c67

File tree

5 files changed

+94
-9
lines changed

5 files changed

+94
-9
lines changed

doc/source/user_guide/text.rst

+4
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ 1. For ``StringDtype``, :ref:`string accessor methods<api.series.str>`
9494
2. Some string methods, like :meth:`Series.str.decode` are not available
9595
on ``StringArray`` because ``StringArray`` only holds strings, not
9696
bytes.
97+
3. In comparison operations, :class:`StringArray` and ``Series`` backed
98+
by a ``StringArray`` will return a :class:`BooleanArray`, rather than
99+
a ``bool`` or ``object`` dtype array, depending on whether there are
100+
missing values.
97101

98102

99103
Everything else that follows in the rest of this document applies equally to

pandas/core/arrays/string_.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,10 @@ class StringArray(PandasArray):
120120
copy : bool, default False
121121
Whether to copy the array of data.
122122
123+
Notes
124+
-----
125+
StringArray returns a BooleanArray for comparison methods.
126+
123127
Attributes
124128
----------
125129
None
@@ -148,6 +152,13 @@ class StringArray(PandasArray):
148152
Traceback (most recent call last):
149153
...
150154
ValueError: StringArray requires an object-dtype ndarray of strings.
155+
156+
For comparison methods, this returns a :class:`pandas.BooleanArray`
157+
158+
>>> pd.array(["a", None, "c"], dtype="string") == "a"
159+
<BooleanArray>
160+
[True, NA, False]
161+
Length: 3, dtype: boolean
151162
"""
152163

153164
# undo the PandasArray hack
@@ -255,7 +266,12 @@ def value_counts(self, dropna=False):
255266
# Override parent because we have different return types.
256267
@classmethod
257268
def _create_arithmetic_method(cls, op):
269+
# Note: this handles both arithmetic and comparison methods.
258270
def method(self, other):
271+
from pandas.arrays import BooleanArray
272+
273+
assert op.__name__ in ops.ARITHMETIC_BINOPS | ops.COMPARISON_BINOPS
274+
259275
if isinstance(other, (ABCIndexClass, ABCSeries, ABCDataFrame)):
260276
return NotImplemented
261277

@@ -275,15 +291,16 @@ def method(self, other):
275291
other = np.asarray(other)
276292
other = other[valid]
277293

278-
result = np.empty_like(self._ndarray, dtype="object")
279-
result[mask] = StringDtype.na_value
280-
result[valid] = op(self._ndarray[valid], other)
281-
282-
if op.__name__ in {"add", "radd", "mul", "rmul"}:
294+
if op.__name__ in ops.ARITHMETIC_BINOPS:
295+
result = np.empty_like(self._ndarray, dtype="object")
296+
result[mask] = StringDtype.na_value
297+
result[valid] = op(self._ndarray[valid], other)
283298
return StringArray(result)
284299
else:
285-
dtype = "object" if mask.any() else "bool"
286-
return np.asarray(result, dtype=dtype)
300+
# comparison: build a boolean result and pass the NA mask to BooleanArray
301+
result = np.zeros(len(self._ndarray), dtype="bool")
302+
result[valid] = op(self._ndarray[valid], other)
303+
return BooleanArray(result, mask)
287304

288305
return compat.set_function_name(method, f"__{op.__name__}__", cls)
289306

pandas/core/ops/__init__.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66
import datetime
77
import operator
8-
from typing import Tuple, Union
8+
from typing import Set, Tuple, Union
99

1010
import numpy as np
1111

@@ -59,6 +59,37 @@
5959
rxor,
6060
)
6161

62+
# -----------------------------------------------------------------------------
63+
# constants
64+
ARITHMETIC_BINOPS: Set[str] = {
65+
"add",
66+
"sub",
67+
"mul",
68+
"pow",
69+
"mod",
70+
"floordiv",
71+
"truediv",
72+
"divmod",
73+
"radd",
74+
"rsub",
75+
"rmul",
76+
"rpow",
77+
"rmod",
78+
"rfloordiv",
79+
"rtruediv",
80+
"rdivmod",
81+
}
82+
83+
84+
COMPARISON_BINOPS: Set[str] = {
85+
"eq",
86+
"ne",
87+
"lt",
88+
"gt",
89+
"le",
90+
"ge",
91+
}
92+
6293
# -----------------------------------------------------------------------------
6394
# Ops Wrapping Utilities
6495

pandas/tests/arrays/string_/test_string.py

+33
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,39 @@ def test_add_frame():
154154
tm.assert_frame_equal(result, expected)
155155

156156

157+
def test_comparison_methods_scalar(all_compare_operators):
158+
op_name = all_compare_operators
159+
160+
a = pd.array(["a", None, "c"], dtype="string")
161+
other = "a"
162+
result = getattr(a, op_name)(other)
163+
expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object)
164+
expected[1] = None
165+
expected = pd.array(expected, dtype="boolean")
166+
tm.assert_extension_array_equal(result, expected)
167+
168+
result = getattr(a, op_name)(pd.NA)
169+
expected = pd.array([None, None, None], dtype="boolean")
170+
tm.assert_extension_array_equal(result, expected)
171+
172+
173+
def test_comparison_methods_array(all_compare_operators):
174+
op_name = all_compare_operators
175+
176+
a = pd.array(["a", None, "c"], dtype="string")
177+
other = [None, None, "c"]
178+
result = getattr(a, op_name)(other)
179+
expected = np.empty_like(a, dtype="object")
180+
expected[:2] = None
181+
expected[-1] = getattr(other[-1], op_name)(a[-1])
182+
expected = pd.array(expected, dtype="boolean")
183+
tm.assert_extension_array_equal(result, expected)
184+
185+
result = getattr(a, op_name)(pd.NA)
186+
expected = pd.array([None, None, None], dtype="boolean")
187+
tm.assert_extension_array_equal(result, expected)
188+
189+
157190
def test_constructor_raises():
158191
with pytest.raises(ValueError, match="sequence of strings"):
159192
pd.arrays.StringArray(np.array(["a", "b"], dtype="S1"))

pandas/tests/extension/test_string.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class TestCasting(base.BaseCastingTests):
9191
class TestComparisonOps(base.BaseComparisonOpsTests):
9292
def _compare_other(self, s, data, op_name, other):
9393
result = getattr(s, op_name)(other)
94-
expected = getattr(s.astype(object), op_name)(other)
94+
expected = getattr(s.astype(object), op_name)(other).astype("boolean")
9595
self.assert_series_equal(result, expected)
9696

9797
def test_compare_scalar(self, data, all_compare_operators):

0 commit comments

Comments
 (0)