Skip to content

Commit 8f04a8e

Browse files
authored
ENH/TST: Add isin, _hasna for ArrowExtensionArray (#47805)
1 parent 433dcd5 commit 8f04a8e

File tree

1 file changed

+48
-2
lines changed

1 file changed

+48
-2
lines changed

pandas/core/arrays/arrow/array.py

+48-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pandas.compat import (
1919
pa_version_under1p01,
2020
pa_version_under2p0,
21+
pa_version_under3p0,
2122
pa_version_under4p0,
2223
pa_version_under5p0,
2324
pa_version_under6p0,
@@ -402,6 +403,10 @@ def __len__(self) -> int:
402403
"""
403404
return len(self._data)
404405

406+
@property
def _hasna(self) -> bool:
    """
    Report whether ``self._data`` contains at least one null entry.

    Returns
    -------
    bool
        True when the ``null_count`` of ``self._data`` is greater than zero.
    """
    null_total = self._data.null_count
    return null_total > 0
409+
405410
def isna(self) -> npt.NDArray[np.bool_]:
406411
"""
407412
Boolean NumPy array indicating if each value is missing.
@@ -439,6 +444,49 @@ def dropna(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
439444
else:
440445
return type(self)(pc.drop_null(self._data))
441446

447+
def isin(self: ArrowExtensionArrayT, values) -> npt.NDArray[np.bool_]:
    """
    Test each element of the array for membership in ``values``.

    Parameters
    ----------
    values : sized collection
        Candidate values. Must support ``len()``; it is converted with
        ``pa.array(values, from_pandas=True)`` before the lookup.

    Returns
    -------
    np.ndarray of bool
        One entry per element of ``self``.
    """
    if pa_version_under2p0:
        # No usable pyarrow kernel here: warn and defer to the parent class.
        fallback_performancewarning(version="2")
        return super().isin(values)

    if len(values) == 0:
        # An empty value_set segfaults on pyarrow 3.0.0, and on pyarrow 2.0.0
        # it yields True for null values, so answer with all-False directly.
        return np.zeros(len(self), dtype=bool)

    # pyarrow 2.0.0 ignores skip_null yet requires the keyword; pyarrow 3.0.0+
    # raises on it as an unexpected keyword argument, so pass it only on 2.x.
    extra_options = {"skip_null": True} if pa_version_under3p0 else {}

    lookup_set = pa.array(values, from_pandas=True)
    membership = pc.is_in(self._data, value_set=lookup_set, **extra_options)

    # pyarrow 2.0.0 returned nulls; forcing the bool dtype converts them
    # to False.
    return np.array(membership, dtype=np.bool_)
469+
470+
def _values_for_factorize(self) -> tuple[np.ndarray, Any]:
    """
    Provide the ndarray and missing-value marker used for factorization.

    Returns
    -------
    values : ndarray
    na_value : pd.NA

    Notes
    -----
    :func:`pandas.util.hash_pandas_object` also consumes the values
    returned by this method.
    """
    missing_marker = self.dtype.na_value
    if pa_version_under2p0:
        # On pyarrow older than 2.0 the conversion goes through pandas.
        return self._data.to_pandas().values, missing_marker
    return self._data.to_numpy(), missing_marker
489+
442490
@doc(ExtensionArray.factorize)
443491
def factorize(
444492
self,
@@ -636,8 +684,6 @@ def _concat_same_type(
636684
-------
637685
ArrowExtensionArray
638686
"""
639-
import pyarrow as pa
640-
641687
chunks = [array for ea in to_concat for array in ea._data.iterchunks()]
642688
arr = pa.chunked_array(chunks)
643689
return cls(arr)

0 commit comments

Comments
 (0)