diff --git a/asv_bench/benchmarks/algos/isin.py b/asv_bench/benchmarks/algos/isin.py index a8b8a193dbcfc..44245295beafc 100644 --- a/asv_bench/benchmarks/algos/isin.py +++ b/asv_bench/benchmarks/algos/isin.py @@ -9,6 +9,8 @@ date_range, ) +from ..pandas_vb_common import tm + class IsIn: @@ -22,6 +24,9 @@ class IsIn: "datetime64[ns]", "category[object]", "category[int]", + "str", + "string", + "arrow_string", ] param_names = ["dtype"] @@ -57,6 +62,15 @@ def setup(self, dtype): self.values = np.random.choice(arr, sample_size) self.series = Series(arr).astype("category") + elif dtype in ["str", "string", "arrow_string"]: + from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401 + + try: + self.series = Series(tm.makeStringIndex(N), dtype=dtype) + except ImportError: + raise NotImplementedError + self.values = list(self.series[:2]) + else: self.series = Series(np.random.randint(1, 10, N)).astype(dtype) self.values = [1, 2] diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 2c4477056a112..c52105b77e4dc 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -468,10 +468,9 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> np.ndarray: comps = _ensure_arraylike(comps) comps = extract_array(comps, extract_numpy=True) - if is_extension_array_dtype(comps.dtype): - # error: Incompatible return value type (got "Series", expected "ndarray") - # error: Item "ndarray" of "Union[Any, ndarray]" has no attribute "isin" - return comps.isin(values) # type: ignore[return-value,union-attr] + if not isinstance(comps, np.ndarray): + # i.e. 
Extension Array + return comps.isin(values) elif needs_i8_conversion(comps.dtype): # Dispatch to DatetimeLikeArrayMixin.isin diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 72a2ab8a1b80a..01813cef97b8d 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -663,6 +663,34 @@ def take( indices_array[indices_array < 0] += len(self._data) return type(self)(self._data.take(indices_array)) + def isin(self, values): + + # pyarrow.compute.is_in added in pyarrow 2.0.0 + if not hasattr(pc, "is_in"): + return super().isin(values) + + value_set = [ + pa_scalar.as_py() + for pa_scalar in [pa.scalar(value, from_pandas=True) for value in values] + if pa_scalar.type in (pa.string(), pa.null()) + ] + + # for an empty value_set pyarrow 3.0.0 segfaults and pyarrow 2.0.0 returns True + # for null values, so we short-circuit to return an all-False array. + if not len(value_set): + return np.zeros(len(self), dtype=bool) + + kwargs = {} + if LooseVersion(pa.__version__) < "3.0.0": + # in pyarrow 2.0.0 skip_null is ignored but is a required keyword and raises + # with unexpected keyword argument in pyarrow 3.0.0+ + kwargs["skip_null"] = True + + result = pc.is_in(self._data, value_set=pa.array(value_set), **kwargs) + # pyarrow 2.0.0 returned nulls, so we explicitly specify dtype to convert nulls + # to False + return np.array(result, dtype=np.bool_) + def value_counts(self, dropna: bool = True) -> Series: """ Return a Series containing counts of each unique value. 
diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index e2d8e522abb35..43ba5667d4d93 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -566,3 +566,23 @@ def test_to_numpy_na_value(dtype, nulls_fixture): result = arr.to_numpy(na_value=na_value) expected = np.array(["a", na_value, "b"], dtype=object) tm.assert_numpy_array_equal(result, expected) + + +def test_isin(dtype, request): + s = pd.Series(["a", "b", None], dtype=dtype) + + result = s.isin(["a", "c"]) + expected = pd.Series([True, False, False]) + tm.assert_series_equal(result, expected) + + result = s.isin(["a", pd.NA]) + expected = pd.Series([True, False, True]) + tm.assert_series_equal(result, expected) + + result = s.isin([]) + expected = pd.Series([False, False, False]) + tm.assert_series_equal(result, expected) + + result = s.isin(["a", pd.Timestamp.now()]) + expected = pd.Series([True, False, False]) + tm.assert_series_equal(result, expected)