Skip to content

Commit 691a2c4

Browse files
[ArrowStringArray] PERF: isin using native pyarrow function if available (#41281)
1 parent 4d73a34 commit 691a2c4

File tree

4 files changed

+65
-4
lines changed

4 files changed

+65
-4
lines changed

asv_bench/benchmarks/algos/isin.py

+14
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
date_range,
1010
)
1111

12+
from ..pandas_vb_common import tm
13+
1214

1315
class IsIn:
1416

@@ -22,6 +24,9 @@ class IsIn:
2224
"datetime64[ns]",
2325
"category[object]",
2426
"category[int]",
27+
"str",
28+
"string",
29+
"arrow_string",
2530
]
2631
param_names = ["dtype"]
2732

@@ -57,6 +62,15 @@ def setup(self, dtype):
5762
self.values = np.random.choice(arr, sample_size)
5863
self.series = Series(arr).astype("category")
5964

65+
elif dtype in ["str", "string", "arrow_string"]:
66+
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
67+
68+
try:
69+
self.series = Series(tm.makeStringIndex(N), dtype=dtype)
70+
except ImportError:
71+
raise NotImplementedError
72+
self.values = list(self.series[:2])
73+
6074
else:
6175
self.series = Series(np.random.randint(1, 10, N)).astype(dtype)
6276
self.values = [1, 2]

pandas/core/algorithms.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -468,10 +468,9 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> np.ndarray:
468468

469469
comps = _ensure_arraylike(comps)
470470
comps = extract_array(comps, extract_numpy=True)
471-
if is_extension_array_dtype(comps.dtype):
472-
# error: Incompatible return value type (got "Series", expected "ndarray")
473-
# error: Item "ndarray" of "Union[Any, ndarray]" has no attribute "isin"
474-
return comps.isin(values) # type: ignore[return-value,union-attr]
471+
if not isinstance(comps, np.ndarray):
472+
# i.e. Extension Array
473+
return comps.isin(values)
475474

476475
elif needs_i8_conversion(comps.dtype):
477476
# Dispatch to DatetimeLikeArrayMixin.isin

pandas/core/arrays/string_arrow.py

+28
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,34 @@ def take(
666666
indices_array[indices_array < 0] += len(self._data)
667667
return type(self)(self._data.take(indices_array))
668668

669+
def isin(self, values):
670+
671+
# pyarrow.compute.is_in added in pyarrow 2.0.0
672+
if not hasattr(pc, "is_in"):
673+
return super().isin(values)
674+
675+
value_set = [
676+
pa_scalar.as_py()
677+
for pa_scalar in [pa.scalar(value, from_pandas=True) for value in values]
678+
if pa_scalar.type in (pa.string(), pa.null())
679+
]
680+
681+
# for an empty value_set pyarrow 3.0.0 segfaults and pyarrow 2.0.0 returns True
682+
# for null values, so we short-circuit to return all False array.
683+
if not len(value_set):
684+
return np.zeros(len(self), dtype=bool)
685+
686+
kwargs = {}
687+
if LooseVersion(pa.__version__) < "3.0.0":
688+
# in pyarrow 2.0.0 skip_null is ignored but is a required keyword and raises
689+
# with unexpected keyword argument in pyarrow 3.0.0+
690+
kwargs["skip_null"] = True
691+
692+
result = pc.is_in(self._data, value_set=pa.array(value_set), **kwargs)
693+
# pyarrow 2.0.0 returned nulls, so we explicily specify dtype to convert nulls
694+
# to False
695+
return np.array(result, dtype=np.bool_)
696+
669697
def value_counts(self, dropna: bool = True) -> Series:
670698
"""
671699
Return a Series containing counts of each unique value.

pandas/tests/arrays/string_/test_string.py

+20
Original file line numberDiff line numberDiff line change
@@ -566,3 +566,23 @@ def test_to_numpy_na_value(dtype, nulls_fixture):
566566
result = arr.to_numpy(na_value=na_value)
567567
expected = np.array(["a", na_value, "b"], dtype=object)
568568
tm.assert_numpy_array_equal(result, expected)
569+
570+
571+
def test_isin(dtype, request):
572+
s = pd.Series(["a", "b", None], dtype=dtype)
573+
574+
result = s.isin(["a", "c"])
575+
expected = pd.Series([True, False, False])
576+
tm.assert_series_equal(result, expected)
577+
578+
result = s.isin(["a", pd.NA])
579+
expected = pd.Series([True, False, True])
580+
tm.assert_series_equal(result, expected)
581+
582+
result = s.isin([])
583+
expected = pd.Series([False, False, False])
584+
tm.assert_series_equal(result, expected)
585+
586+
result = s.isin(["a", pd.Timestamp.now()])
587+
expected = pd.Series([True, False, False])
588+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)