Skip to content

Commit 03a8009

Browse files
committed
REF: Move value_counts, isin to ArrowExtensionArray
1 parent 6033ed4 commit 03a8009

File tree

2 files changed

+74
-73
lines changed

2 files changed

+74
-73
lines changed

pandas/core/arrays/_mixins.py

+68
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from pandas.compat import (
3333
pa_version_under1p01,
3434
pa_version_under2p0,
35+
pa_version_under3p0,
3536
pa_version_under5p0,
3637
)
3738
from pandas.errors import AbstractMethodError
@@ -86,6 +87,8 @@
8687
NumpyValueArrayLike,
8788
)
8889

90+
from pandas import Series
91+
8992

9093
def ravel_compat(meth: F) -> F:
9194
"""
@@ -544,6 +547,7 @@ class ArrowExtensionArray(ExtensionArray):
544547
"""
545548

546549
_data: pa.ChunkedArray
550+
_pa_dtype: pa.DataType()
547551

548552
def __init__(self, values: pa.ChunkedArray) -> None:
549553
self._data = values
@@ -599,6 +603,70 @@ def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
599603
"""
600604
return type(self)(self._data)
601605

606+
def isin(self, values):
607+
if pa_version_under2p0:
608+
return super().isin(values)
609+
610+
value_set = [
611+
pa_scalar.as_py()
612+
for pa_scalar in [pa.scalar(value, from_pandas=True) for value in values]
613+
if pa_scalar.type in (self._pa_dtype, pa.null())
614+
]
615+
616+
# for an empty value_set pyarrow 3.0.0 segfaults and pyarrow 2.0.0 returns True
617+
# for null values, so we short-circuit to return all False array.
618+
if not len(value_set):
619+
return np.zeros(len(self), dtype=bool)
620+
621+
kwargs = {}
622+
if pa_version_under3p0:
623+
# in pyarrow 2.0.0 skip_null is ignored but is a required keyword and raises
624+
# with unexpected keyword argument in pyarrow 3.0.0+
625+
kwargs["skip_null"] = True
626+
627+
result = pc.is_in(self._data, value_set=pa.array(value_set), **kwargs)
628+
# pyarrow 2.0.0 returned nulls, so we explicitly specify dtype to convert nulls
629+
# to False
630+
return np.array(result, dtype=np.bool_)
631+
632+
def value_counts(self, dropna: bool = True) -> Series:
633+
"""
634+
Return a Series containing counts of each unique value.
635+
636+
Parameters
637+
----------
638+
dropna : bool, default True
639+
Don't include counts of missing values.
640+
641+
Returns
642+
-------
643+
counts : Series
644+
645+
See Also
646+
--------
647+
Series.value_counts
648+
"""
649+
from pandas import (
650+
Index,
651+
Series,
652+
)
653+
654+
vc = self._data.value_counts()
655+
656+
values = vc.field(0)
657+
counts = vc.field(1)
658+
if dropna and self._data.null_count > 0:
659+
mask = values.is_valid()
660+
values = values.filter(mask)
661+
counts = counts.filter(mask)
662+
663+
# No missing values so we can adhere to the interface and return a numpy array.
664+
counts = np.array(counts)
665+
666+
index = Index(type(self)(values))
667+
668+
return Series(counts, index=index).astype("Int64")
669+
602670
@classmethod
603671
def _concat_same_type(
604672
cls: type[ArrowExtensionArrayT], to_concat

pandas/core/arrays/string_arrow.py

+6-73
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from collections.abc import Callable # noqa: PDF001
44
import re
55
from typing import (
6-
TYPE_CHECKING,
76
Any,
87
Union,
98
overload,
@@ -28,7 +27,6 @@
2827
from pandas.compat import (
2928
pa_version_under1p01,
3029
pa_version_under2p0,
31-
pa_version_under3p0,
3230
pa_version_under4p0,
3331
)
3432
from pandas.util._decorators import doc
@@ -77,9 +75,6 @@
7775
}
7876

7977

80-
if TYPE_CHECKING:
81-
from pandas import Series
82-
8378
ArrowStringScalarOrNAT = Union[str, libmissing.NAType]
8479

8580

@@ -140,6 +135,8 @@ class ArrowStringArray(
140135
Length: 4, dtype: string
141136
"""
142137

138+
_pa_dtype = pa.string()
139+
143140
def __init__(self, values) -> None:
144141
self._dtype = StringDtype(storage="pyarrow")
145142
if isinstance(values, pa.Array):
@@ -170,11 +167,11 @@ def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False)
170167
na_values = scalars._mask
171168
result = scalars._data
172169
result = lib.ensure_string_array(result, copy=copy, convert_na_value=False)
173-
return cls(pa.array(result, mask=na_values, type=pa.string()))
170+
return cls(pa.array(result, mask=na_values, type=cls._pa_dtype))
174171

175172
# convert non-na-likes to str
176173
result = lib.ensure_string_array(scalars, copy=copy)
177-
return cls(pa.array(result, type=pa.string(), from_pandas=True))
174+
return cls(pa.array(result, type=cls._pa_dtype, from_pandas=True))
178175

179176
@classmethod
180177
def _from_sequence_of_strings(
@@ -269,7 +266,7 @@ def __getitem__(
269266

270267
if isinstance(item, np.ndarray):
271268
if not len(item):
272-
return type(self)(pa.chunked_array([], type=pa.string()))
269+
return type(self)(pa.chunked_array([], type=self._pa_dtype))
273270
elif is_integer_dtype(item.dtype):
274271
return self.take(item)
275272
elif is_bool_dtype(item.dtype):
@@ -455,70 +452,6 @@ def take(
455452
indices_array[indices_array < 0] += len(self._data)
456453
return type(self)(self._data.take(indices_array))
457454

458-
def isin(self, values):
459-
if pa_version_under2p0:
460-
return super().isin(values)
461-
462-
value_set = [
463-
pa_scalar.as_py()
464-
for pa_scalar in [pa.scalar(value, from_pandas=True) for value in values]
465-
if pa_scalar.type in (pa.string(), pa.null())
466-
]
467-
468-
# for an empty value_set pyarrow 3.0.0 segfaults and pyarrow 2.0.0 returns True
469-
# for null values, so we short-circuit to return all False array.
470-
if not len(value_set):
471-
return np.zeros(len(self), dtype=bool)
472-
473-
kwargs = {}
474-
if pa_version_under3p0:
475-
# in pyarrow 2.0.0 skip_null is ignored but is a required keyword and raises
476-
# with unexpected keyword argument in pyarrow 3.0.0+
477-
kwargs["skip_null"] = True
478-
479-
result = pc.is_in(self._data, value_set=pa.array(value_set), **kwargs)
480-
# pyarrow 2.0.0 returned nulls, so we explicitly specify dtype to convert nulls
481-
# to False
482-
return np.array(result, dtype=np.bool_)
483-
484-
def value_counts(self, dropna: bool = True) -> Series:
485-
"""
486-
Return a Series containing counts of each unique value.
487-
488-
Parameters
489-
----------
490-
dropna : bool, default True
491-
Don't include counts of missing values.
492-
493-
Returns
494-
-------
495-
counts : Series
496-
497-
See Also
498-
--------
499-
Series.value_counts
500-
"""
501-
from pandas import (
502-
Index,
503-
Series,
504-
)
505-
506-
vc = self._data.value_counts()
507-
508-
values = vc.field(0)
509-
counts = vc.field(1)
510-
if dropna and self._data.null_count > 0:
511-
mask = values.is_valid()
512-
values = values.filter(mask)
513-
counts = counts.filter(mask)
514-
515-
# No missing values so we can adhere to the interface and return a numpy array.
516-
counts = np.array(counts)
517-
518-
index = Index(type(self)(values))
519-
520-
return Series(counts, index=index).astype("Int64")
521-
522455
def astype(self, dtype, copy: bool = True):
523456
dtype = pandas_dtype(dtype)
524457

@@ -590,7 +523,7 @@ def _str_map(
590523
result = lib.map_infer_mask(
591524
arr, f, mask.view("uint8"), convert=False, na_value=na_value
592525
)
593-
result = pa.array(result, mask=mask, type=pa.string(), from_pandas=True)
526+
result = pa.array(result, mask=mask, type=self._pa_dtype, from_pandas=True)
594527
return type(self)(result)
595528
else:
596529
# This is when the result type is object. We reach this when

0 commit comments

Comments
 (0)