Skip to content

Commit d47cca7

Browse files
jbrockmendelWillAyd
authored andcommitted
REF (string): move ArrowStringArrayNumpySemantics methods to base class (pandas-dev#59501)
* REF: move ArrowStringArrayNumpySemantics methods to parent class * REF: move methods to ArrowStringArray * mypy fixup * Fix incorrect double-unpacking * move methods to subclass
1 parent 4c95cba commit d47cca7

File tree

1 file changed

+48
-61
lines changed

1 file changed

+48
-61
lines changed

pandas/core/arrays/string_arrow.py

+48-61
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from functools import partial
43
import operator
54
import re
65
from typing import (
@@ -209,12 +208,17 @@ def dtype(self) -> StringDtype: # type: ignore[override]
209208
return self._dtype
210209

211210
def insert(self, loc: int, item) -> ArrowStringArray:
211+
if self.dtype.na_value is np.nan and item is np.nan:
212+
item = libmissing.NA
212213
if not isinstance(item, str) and item is not libmissing.NA:
213214
raise TypeError("Scalar must be NA or str")
214215
return super().insert(loc, item)
215216

216-
@classmethod
217-
def _result_converter(cls, values, na=None):
217+
def _result_converter(self, values, na=None):
218+
if self.dtype.na_value is np.nan:
219+
if not isna(na):
220+
values = values.fill_null(bool(na))
221+
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
218222
return BooleanDtype().__from_arrow__(values)
219223

220224
def _maybe_convert_setitem_value(self, value):
@@ -494,11 +498,30 @@ def _str_get_dummies(self, sep: str = "|"):
494498
return dummies.astype(np.int64, copy=False), labels
495499

496500
def _convert_int_dtype(self, result):
501+
if self.dtype.na_value is np.nan:
502+
if isinstance(result, pa.Array):
503+
result = result.to_numpy(zero_copy_only=False)
504+
else:
505+
result = result.to_numpy()
506+
if result.dtype == np.int32:
507+
result = result.astype(np.int64)
508+
return result
509+
497510
return Int64Dtype().__from_arrow__(result)
498511

499512
def _reduce(
500513
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
501514
):
515+
if self.dtype.na_value is np.nan and name in ["any", "all"]:
516+
if not skipna:
517+
nas = pc.is_null(self._pa_array)
518+
arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, ""))
519+
else:
520+
arr = pc.not_equal(self._pa_array, "")
521+
return ArrowExtensionArray(arr)._reduce(
522+
name, skipna=skipna, keepdims=keepdims, **kwargs
523+
)
524+
502525
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
503526
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
504527
return self._convert_int_dtype(result)
@@ -529,67 +552,31 @@ def _rank(
529552
)
530553
)
531554

532-
533-
class ArrowStringArrayNumpySemantics(ArrowStringArray):
534-
_storage = "pyarrow"
535-
_na_value = np.nan
536-
537-
@classmethod
538-
def _result_converter(cls, values, na=None):
539-
if not isna(na):
540-
values = values.fill_null(bool(na))
541-
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
542-
543-
def __getattribute__(self, item):
544-
# ArrowStringArray and we both inherit from ArrowExtensionArray, which
545-
# creates inheritance problems (Diamond inheritance)
546-
if item in ArrowStringArrayMixin.__dict__ and item not in (
547-
"_pa_array",
548-
"__dict__",
549-
):
550-
return partial(getattr(ArrowStringArrayMixin, item), self)
551-
return super().__getattribute__(item)
552-
553-
def _convert_int_dtype(self, result):
554-
if isinstance(result, pa.Array):
555-
result = result.to_numpy(zero_copy_only=False)
556-
else:
557-
result = result.to_numpy()
558-
if result.dtype == np.int32:
559-
result = result.astype(np.int64)
555+
def value_counts(self, dropna: bool = True) -> Series:
556+
result = super().value_counts(dropna=dropna)
557+
if self.dtype.na_value is np.nan:
558+
res_values = result._values.to_numpy()
559+
return result._constructor(
560+
res_values, index=result.index, name=result.name, copy=False
561+
)
560562
return result
561563

562564
def _cmp_method(self, other, op):
563565
result = super()._cmp_method(other, op)
564-
if op == operator.ne:
565-
return result.to_numpy(np.bool_, na_value=True)
566-
else:
567-
return result.to_numpy(np.bool_, na_value=False)
568-
569-
def value_counts(self, dropna: bool = True) -> Series:
570-
from pandas import Series
571-
572-
result = super().value_counts(dropna)
573-
return Series(
574-
result._values.to_numpy(), index=result.index, name=result.name, copy=False
575-
)
576-
577-
def _reduce(
578-
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
579-
):
580-
if name in ["any", "all"]:
581-
if not skipna and name == "all":
582-
nas = pc.invert(pc.is_null(self._pa_array))
583-
arr = pc.and_kleene(nas, pc.not_equal(self._pa_array, ""))
566+
if self.dtype.na_value is np.nan:
567+
if op == operator.ne:
568+
return result.to_numpy(np.bool_, na_value=True)
584569
else:
585-
arr = pc.not_equal(self._pa_array, "")
586-
return ArrowExtensionArray(arr)._reduce(
587-
name, skipna=skipna, keepdims=keepdims, **kwargs
588-
)
589-
else:
590-
return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)
570+
return result.to_numpy(np.bool_, na_value=False)
571+
return result
591572

592-
def insert(self, loc: int, item) -> ArrowStringArrayNumpySemantics:
593-
if item is np.nan:
594-
item = libmissing.NA
595-
return super().insert(loc, item) # type: ignore[return-value]
573+
574+
class ArrowStringArrayNumpySemantics(ArrowStringArray):
575+
_na_value = np.nan
576+
_str_get = ArrowStringArrayMixin._str_get
577+
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
578+
_str_capitalize = ArrowStringArrayMixin._str_capitalize
579+
_str_pad = ArrowStringArrayMixin._str_pad
580+
_str_title = ArrowStringArrayMixin._str_title
581+
_str_swapcase = ArrowStringArrayMixin._str_swapcase
582+
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace

0 commit comments

Comments
 (0)