Skip to content

Commit 532b9a1

Browse files
String dtype: fix isin() values handling for python storage (pandas-dev#59759)
* String dtype: fix isin() values handling for python storage * address feedback
1 parent 37886a6 commit 532b9a1

File tree

3 files changed

+64
-6
lines changed

3 files changed

+64
-6
lines changed

pandas/conftest.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1294,7 +1294,13 @@ def string_storage(request):
12941294
pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
12951295
pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
12961296
("python", np.nan),
1297-
]
1297+
],
1298+
ids=[
1299+
"string=string[python]",
1300+
"string=string[pyarrow]",
1301+
"string=str[pyarrow]",
1302+
"string=str[python]",
1303+
],
12981304
)
12991305
def string_dtype_arguments(request):
13001306
"""
@@ -1325,6 +1331,7 @@ def dtype_backend(request):
13251331

13261332
# Alias so we can test with cartesian product of string_storage
13271333
string_storage2 = string_storage
1334+
string_dtype_arguments2 = string_dtype_arguments
13281335

13291336

13301337
@pytest.fixture(params=tm.BYTES_DTYPES)

pandas/core/arrays/string_.py

+20
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
nanops,
4747
ops,
4848
)
49+
from pandas.core.algorithms import isin
4950
from pandas.core.array_algos import masked_reductions
5051
from pandas.core.arrays.base import ExtensionArray
5152
from pandas.core.arrays.floating import (
@@ -65,6 +66,7 @@
6566
import pyarrow
6667

6768
from pandas._typing import (
69+
ArrayLike,
6870
AxisInt,
6971
Dtype,
7072
DtypeObj,
@@ -733,6 +735,24 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
733735
# base class implementation that uses __setitem__
734736
ExtensionArray._putmask(self, mask, value)
735737

738+
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
739+
if isinstance(values, BaseStringArray) or (
740+
isinstance(values, ExtensionArray) and is_string_dtype(values.dtype)
741+
):
742+
values = values.astype(self.dtype, copy=False)
743+
else:
744+
if not lib.is_string_array(np.asarray(values), skipna=True):
745+
values = np.array(
746+
[val for val in values if isinstance(val, str) or isna(val)],
747+
dtype=object,
748+
)
749+
if not len(values):
750+
return np.zeros(self.shape, dtype=bool)
751+
752+
values = self._from_sequence(values, dtype=self.dtype)
753+
754+
return isin(np.asarray(self), np.asarray(values))
755+
736756
def astype(self, dtype, copy: bool = True):
737757
dtype = pandas_dtype(dtype)
738758

pandas/tests/arrays/string_/test_string.py

+36-5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ def dtype(string_dtype_arguments):
2929
return pd.StringDtype(storage=storage, na_value=na_value)
3030

3131

32+
@pytest.fixture
33+
def dtype2(string_dtype_arguments2):
34+
storage, na_value = string_dtype_arguments2
35+
return pd.StringDtype(storage=storage, na_value=na_value)
36+
37+
3238
@pytest.fixture
3339
def cls(dtype):
3440
"""Fixture giving array type from parametrized 'dtype'"""
@@ -689,11 +695,7 @@ def test_isin(dtype, fixed_now_ts):
689695
tm.assert_series_equal(result, expected)
690696

691697
result = s.isin(["a", pd.NA])
692-
if dtype.storage == "python" and dtype.na_value is np.nan:
693-
# TODO(infer_string) we should make this consistent
694-
expected = pd.Series([True, False, False])
695-
else:
696-
expected = pd.Series([True, False, True])
698+
expected = pd.Series([True, False, True])
697699
tm.assert_series_equal(result, expected)
698700

699701
result = s.isin([])
@@ -704,6 +706,35 @@ def test_isin(dtype, fixed_now_ts):
704706
expected = pd.Series([True, False, False])
705707
tm.assert_series_equal(result, expected)
706708

709+
result = s.isin([fixed_now_ts])
710+
expected = pd.Series([False, False, False])
711+
tm.assert_series_equal(result, expected)
712+
713+
714+
def test_isin_string_array(dtype, dtype2):
715+
s = pd.Series(["a", "b", None], dtype=dtype)
716+
717+
result = s.isin(pd.array(["a", "c"], dtype=dtype2))
718+
expected = pd.Series([True, False, False])
719+
tm.assert_series_equal(result, expected)
720+
721+
result = s.isin(pd.array(["a", None], dtype=dtype2))
722+
expected = pd.Series([True, False, True])
723+
tm.assert_series_equal(result, expected)
724+
725+
726+
def test_isin_arrow_string_array(dtype):
727+
pa = pytest.importorskip("pyarrow")
728+
s = pd.Series(["a", "b", None], dtype=dtype)
729+
730+
result = s.isin(pd.array(["a", "c"], dtype=pd.ArrowDtype(pa.string())))
731+
expected = pd.Series([True, False, False])
732+
tm.assert_series_equal(result, expected)
733+
734+
result = s.isin(pd.array(["a", None], dtype=pd.ArrowDtype(pa.string())))
735+
expected = pd.Series([True, False, True])
736+
tm.assert_series_equal(result, expected)
737+
707738

708739
def test_setitem_scalar_with_mask_validation(dtype):
709740
# https://github.com/pandas-dev/pandas/issues/47628

0 commit comments

Comments
 (0)