Skip to content

Commit 846afff

Browse files
committed
Fix ArrowExtensionArray and add whatsnew
1 parent 5ba3577 commit 846afff

File tree

4 files changed

+37
-8
lines changed

4 files changed

+37
-8
lines changed

doc/source/whatsnew/v2.3.0.rst

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,20 @@ Notable bug fixes
5050

5151
These are bug fixes that might have notable behavior changes.
5252

53-
.. _whatsnew_230.notable_bug_fixes.notable_bug_fix1:
53+
.. _whatsnew_230.notable_bug_fixes.string_comparisons:
5454

55-
notable_bug_fix1
56-
^^^^^^^^^^^^^^^^
55+
Comparisons between different string dtypes
56+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
57+
58+
In previous versions, comparing Series of different string dtypes (e.g. ``pd.StringDtype("pyarrow", na_value=pd.NA)`` against ``pd.StringDtype("python", na_value=np.nan)``) would produce an inconsistent result dtype or incorrectly raise. pandas will now use the hierarchy
59+
60+
object < (python, NaN) < (pyarrow, NaN) < (python, NA) < (pyarrow, NA)
61+
62+
to determine the result dtype when different string dtypes are compared. Some examples:
63+
64+
- When ``pd.StringDtype("pyarrow", na_value=pd.NA)`` is compared against any other string dtype, the result will always be ``boolean[pyarrow]``.
65+
- When ``pd.StringDtype("python", na_value=pd.NA)`` is compared against ``pd.StringDtype("pyarrow", na_value=np.nan)``, the result will be ``boolean``, the dtype of the NumPy-backed nullable boolean extension array.
66+
- When ``pd.StringDtype("python", na_value=pd.NA)`` is compared against ``pd.StringDtype("python", na_value=np.nan)``, the result will be ``boolean``, the dtype of the NumPy-backed nullable boolean extension array.
5767

5868
.. _whatsnew_230.api_changes:
5969

pandas/core/arrays/string_arrow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def value_counts(self, dropna: bool = True) -> Series:
474474

475475
def _cmp_method(self, other, op):
476476
if (
477-
isinstance(other, BaseStringArray)
477+
isinstance(other, (BaseStringArray, ArrowExtensionArray))
478478
and self.dtype.na_value is not libmissing.NA
479479
and other.dtype.na_value is libmissing.NA
480480
):

pandas/tests/arrays/string_/test_string.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
pa_version_under12p0,
1616
pa_version_under19p0,
1717
)
18+
import pandas.util._test_decorators as td
1819

1920
from pandas.core.dtypes.common import is_dtype_equal
2021

@@ -371,6 +372,26 @@ def test_comparison_methods_array(comparison_op, dtype, dtype2):
371372
tm.assert_extension_array_equal(result, expected)
372373

373374

375+
@td.skip_if_no("pyarrow")
376+
def test_comparison_methods_array_arrow_extension(comparison_op, dtype2):
377+
# Test pd.ArrowDtype(pa.string()) against other string arrays
378+
import pyarrow as pa
379+
380+
op_name = f"__{comparison_op.__name__}__"
381+
dtype = pd.ArrowDtype(pa.string())
382+
a = pd.array(["a", None, "c"], dtype=dtype)
383+
other = pd.array([None, None, "c"], dtype=dtype2)
384+
result = comparison_op(a, other)
385+
386+
# ensure the result is the same regardless of operand order
387+
result2 = comparison_op(other, a)
388+
tm.assert_equal(result, result2)
389+
390+
expected = pd.array([None, None, True], dtype="bool[pyarrow]")
391+
expected[-1] = getattr(other[-1], op_name)(a[-1])
392+
tm.assert_extension_array_equal(result, expected)
393+
394+
374395
def test_comparison_methods_list(comparison_op, dtype):
375396
op_name = f"__{comparison_op.__name__}__"
376397

pandas/tests/extension/test_string.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,8 @@ def test_arith_series_with_array(
240240
if (
241241
using_infer_string
242242
and all_arithmetic_operators == "__radd__"
243-
and (
244-
dtype.na_value is pd.NA
245-
and not (not HAS_PYARROW and dtype.storage == "python")
246-
)
243+
and dtype.na_value is pd.NA
244+
and (HAS_PYARROW or dtype.storage != "pyarrow")
247245
):
248246
# TODO(infer_string)
249247
mark = pytest.mark.xfail(

0 commit comments

Comments
 (0)