Skip to content

Commit 8bece71

Browse files
authored
REF (string): move ArrowStringArrayNumpySemantics methods to base class (#59501)
* REF: move ArrowStringArrayNumpySemantics methods to parent class
* REF: move methods to ArrowStringArray
* mypy fixup
* Fix incorrect double-unpacking
* move methods to subclass
1 parent 66e465e commit 8bece71

File tree

1 file changed

+48
-61
lines changed

1 file changed

+48
-61
lines changed

pandas/core/arrays/string_arrow.py

+48-61
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from functools import partial
43
import operator
54
import re
65
from typing import (
@@ -216,12 +215,17 @@ def dtype(self) -> StringDtype: # type: ignore[override]
216215
return self._dtype
217216

218217
def insert(self, loc: int, item) -> ArrowStringArray:
218+
if self.dtype.na_value is np.nan and item is np.nan:
219+
item = libmissing.NA
219220
if not isinstance(item, str) and item is not libmissing.NA:
220221
raise TypeError("Scalar must be NA or str")
221222
return super().insert(loc, item)
222223

223-
@classmethod
224-
def _result_converter(cls, values, na=None):
224+
def _result_converter(self, values, na=None):
225+
if self.dtype.na_value is np.nan:
226+
if not isna(na):
227+
values = values.fill_null(bool(na))
228+
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
225229
return BooleanDtype().__from_arrow__(values)
226230

227231
def _maybe_convert_setitem_value(self, value):
@@ -492,11 +496,30 @@ def _str_get_dummies(self, sep: str = "|"):
492496
return dummies.astype(np.int64, copy=False), labels
493497

494498
def _convert_int_dtype(self, result):
499+
if self.dtype.na_value is np.nan:
500+
if isinstance(result, pa.Array):
501+
result = result.to_numpy(zero_copy_only=False)
502+
else:
503+
result = result.to_numpy()
504+
if result.dtype == np.int32:
505+
result = result.astype(np.int64)
506+
return result
507+
495508
return Int64Dtype().__from_arrow__(result)
496509

497510
def _reduce(
498511
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
499512
):
513+
if self.dtype.na_value is np.nan and name in ["any", "all"]:
514+
if not skipna:
515+
nas = pc.is_null(self._pa_array)
516+
arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, ""))
517+
else:
518+
arr = pc.not_equal(self._pa_array, "")
519+
return ArrowExtensionArray(arr)._reduce(
520+
name, skipna=skipna, keepdims=keepdims, **kwargs
521+
)
522+
500523
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
501524
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
502525
return self._convert_int_dtype(result)
@@ -527,67 +550,31 @@ def _rank(
527550
)
528551
)
529552

530-
531-
class ArrowStringArrayNumpySemantics(ArrowStringArray):
532-
_storage = "pyarrow"
533-
_na_value = np.nan
534-
535-
@classmethod
536-
def _result_converter(cls, values, na=None):
537-
if not isna(na):
538-
values = values.fill_null(bool(na))
539-
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
540-
541-
def __getattribute__(self, item):
542-
# ArrowStringArray and we both inherit from ArrowExtensionArray, which
543-
# creates inheritance problems (Diamond inheritance)
544-
if item in ArrowStringArrayMixin.__dict__ and item not in (
545-
"_pa_array",
546-
"__dict__",
547-
):
548-
return partial(getattr(ArrowStringArrayMixin, item), self)
549-
return super().__getattribute__(item)
550-
551-
def _convert_int_dtype(self, result):
552-
if isinstance(result, pa.Array):
553-
result = result.to_numpy(zero_copy_only=False)
554-
else:
555-
result = result.to_numpy()
556-
if result.dtype == np.int32:
557-
result = result.astype(np.int64)
553+
def value_counts(self, dropna: bool = True) -> Series:
554+
result = super().value_counts(dropna=dropna)
555+
if self.dtype.na_value is np.nan:
556+
res_values = result._values.to_numpy()
557+
return result._constructor(
558+
res_values, index=result.index, name=result.name, copy=False
559+
)
558560
return result
559561

560562
def _cmp_method(self, other, op):
561563
result = super()._cmp_method(other, op)
562-
if op == operator.ne:
563-
return result.to_numpy(np.bool_, na_value=True)
564-
else:
565-
return result.to_numpy(np.bool_, na_value=False)
566-
567-
def value_counts(self, dropna: bool = True) -> Series:
568-
from pandas import Series
569-
570-
result = super().value_counts(dropna)
571-
return Series(
572-
result._values.to_numpy(), index=result.index, name=result.name, copy=False
573-
)
574-
575-
def _reduce(
576-
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
577-
):
578-
if name in ["any", "all"]:
579-
if not skipna:
580-
nas = pc.is_null(self._pa_array)
581-
arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, ""))
564+
if self.dtype.na_value is np.nan:
565+
if op == operator.ne:
566+
return result.to_numpy(np.bool_, na_value=True)
582567
else:
583-
arr = pc.not_equal(self._pa_array, "")
584-
return ArrowExtensionArray(arr)._reduce(
585-
name, skipna=skipna, keepdims=keepdims, **kwargs
586-
)
587-
else:
588-
return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)
568+
return result.to_numpy(np.bool_, na_value=False)
569+
return result
589570

590-
def insert(self, loc: int, item) -> ArrowStringArrayNumpySemantics:
591-
if item is np.nan:
592-
item = libmissing.NA
593-
return super().insert(loc, item) # type: ignore[return-value]
571+
572+
class ArrowStringArrayNumpySemantics(ArrowStringArray):
573+
_na_value = np.nan
574+
_str_get = ArrowStringArrayMixin._str_get
575+
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
576+
_str_capitalize = ArrowStringArrayMixin._str_capitalize
577+
_str_pad = ArrowStringArrayMixin._str_pad
578+
_str_title = ArrowStringArrayMixin._str_title
579+
_str_swapcase = ArrowStringArrayMixin._str_swapcase
580+
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace

0 commit comments

Comments
 (0)