Skip to content

Commit 96b71d1

Browse files
TomAugspurger authored and proost committed
StringArray comparisons return BooleanArray (pandas-dev#30231)
xref pandas-dev#29556
1 parent 3333eec commit 96b71d1

File tree

5 files changed

+93
-10
lines changed

5 files changed

+93
-10
lines changed

doc/source/user_guide/text.rst

+5-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,11 @@ 1. For ``StringDtype``, :ref:`string accessor methods<api.series.str>`
9494
2. Some string methods, like :meth:`Series.str.decode` are not available
9595
on ``StringArray`` because ``StringArray`` only holds strings, not
9696
bytes.
97-
97+
3. In comparison operations, :class:`arrays.StringArray` and ``Series`` backed
98+
by a ``StringArray`` will return an object with :class:`BooleanDtype`,
99+
rather than a ``bool`` dtype object. Missing values in a ``StringArray``
100+
will propagate in comparison operations, rather than always comparing
101+
unequal like :attr:`numpy.nan`.
98102

99103
Everything else that follows in the rest of this document applies equally to
100104
``string`` and ``object`` dtype.

pandas/core/arrays/string_.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ class StringArray(PandasArray):
134134
The string methods are available on Series backed by
135135
a StringArray.
136136
137+
Notes
138+
-----
139+
StringArray returns a BooleanArray for comparison methods.
140+
137141
Examples
138142
--------
139143
>>> pd.array(['This is', 'some text', None, 'data.'], dtype="string")
@@ -148,6 +152,13 @@ class StringArray(PandasArray):
148152
Traceback (most recent call last):
149153
...
150154
ValueError: StringArray requires an object-dtype ndarray of strings.
155+
156+
For comparison methods, this returns a :class:`pandas.BooleanArray`
157+
158+
>>> pd.array(["a", None, "c"], dtype="string") == "a"
159+
<BooleanArray>
160+
[True, NA, False]
161+
Length: 3, dtype: boolean
151162
"""
152163

153164
# undo the PandasArray hack
@@ -255,7 +266,12 @@ def value_counts(self, dropna=False):
255266
# Override parent because we have different return types.
256267
@classmethod
257268
def _create_arithmetic_method(cls, op):
269+
# Note: this handles both arithmetic and comparison methods.
258270
def method(self, other):
271+
from pandas.arrays import BooleanArray
272+
273+
assert op.__name__ in ops.ARITHMETIC_BINOPS | ops.COMPARISON_BINOPS
274+
259275
if isinstance(other, (ABCIndexClass, ABCSeries, ABCDataFrame)):
260276
return NotImplemented
261277

@@ -275,15 +291,16 @@ def method(self, other):
275291
other = np.asarray(other)
276292
other = other[valid]
277293

278-
result = np.empty_like(self._ndarray, dtype="object")
279-
result[mask] = StringDtype.na_value
280-
result[valid] = op(self._ndarray[valid], other)
281-
282-
if op.__name__ in {"add", "radd", "mul", "rmul"}:
294+
if op.__name__ in ops.ARITHMETIC_BINOPS:
295+
result = np.empty_like(self._ndarray, dtype="object")
296+
result[mask] = StringDtype.na_value
297+
result[valid] = op(self._ndarray[valid], other)
283298
return StringArray(result)
284299
else:
285-
dtype = "object" if mask.any() else "bool"
286-
return np.asarray(result, dtype=dtype)
300+
# logical
301+
result = np.zeros(len(self._ndarray), dtype="bool")
302+
result[valid] = op(self._ndarray[valid], other)
303+
return BooleanArray(result, mask)
287304

288305
return compat.set_function_name(method, f"__{op.__name__}__", cls)
289306

pandas/core/ops/__init__.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66
import datetime
77
import operator
8-
from typing import Tuple, Union
8+
from typing import Set, Tuple, Union
99

1010
import numpy as np
1111

@@ -59,6 +59,37 @@
5959
rxor,
6060
)
6161

62+
# -----------------------------------------------------------------------------
63+
# constants
64+
ARITHMETIC_BINOPS: Set[str] = {
65+
"add",
66+
"sub",
67+
"mul",
68+
"pow",
69+
"mod",
70+
"floordiv",
71+
"truediv",
72+
"divmod",
73+
"radd",
74+
"rsub",
75+
"rmul",
76+
"rpow",
77+
"rmod",
78+
"rfloordiv",
79+
"rtruediv",
80+
"rdivmod",
81+
}
82+
83+
84+
COMPARISON_BINOPS: Set[str] = {
85+
"eq",
86+
"ne",
87+
"lt",
88+
"gt",
89+
"le",
90+
"ge",
91+
}
92+
6293
# -----------------------------------------------------------------------------
6394
# Ops Wrapping Utilities
6495

pandas/tests/arrays/string_/test_string.py

+31
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,37 @@ def test_add_frame():
154154
tm.assert_frame_equal(result, expected)
155155

156156

157+
def test_comparison_methods_scalar(all_compare_operators):
158+
op_name = all_compare_operators
159+
160+
a = pd.array(["a", None, "c"], dtype="string")
161+
other = "a"
162+
result = getattr(a, op_name)(other)
163+
expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object)
164+
expected = pd.array(expected, dtype="boolean")
165+
tm.assert_extension_array_equal(result, expected)
166+
167+
result = getattr(a, op_name)(pd.NA)
168+
expected = pd.array([None, None, None], dtype="boolean")
169+
tm.assert_extension_array_equal(result, expected)
170+
171+
172+
def test_comparison_methods_array(all_compare_operators):
173+
op_name = all_compare_operators
174+
175+
a = pd.array(["a", None, "c"], dtype="string")
176+
other = [None, None, "c"]
177+
result = getattr(a, op_name)(other)
178+
expected = np.empty_like(a, dtype="object")
179+
expected[-1] = getattr(other[-1], op_name)(a[-1])
180+
expected = pd.array(expected, dtype="boolean")
181+
tm.assert_extension_array_equal(result, expected)
182+
183+
result = getattr(a, op_name)(pd.NA)
184+
expected = pd.array([None, None, None], dtype="boolean")
185+
tm.assert_extension_array_equal(result, expected)
186+
187+
157188
def test_constructor_raises():
158189
with pytest.raises(ValueError, match="sequence of strings"):
159190
pd.arrays.StringArray(np.array(["a", "b"], dtype="S1"))

pandas/tests/extension/test_string.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class TestCasting(base.BaseCastingTests):
9191
class TestComparisonOps(base.BaseComparisonOpsTests):
9292
def _compare_other(self, s, data, op_name, other):
9393
result = getattr(s, op_name)(other)
94-
expected = getattr(s.astype(object), op_name)(other)
94+
expected = getattr(s.astype(object), op_name)(other).astype("boolean")
9595
self.assert_series_equal(result, expected)
9696

9797
def test_compare_scalar(self, data, all_compare_operators):

0 commit comments

Comments
 (0)