|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -from functools import partial |
4 | 3 | import operator
|
5 | 4 | import re
|
6 | 5 | from typing import (
|
@@ -209,12 +208,17 @@ def dtype(self) -> StringDtype: # type: ignore[override]
|
209 | 208 | return self._dtype
|
210 | 209 |
|
211 | 210 | def insert(self, loc: int, item) -> ArrowStringArray:
|
| 211 | + if self.dtype.na_value is np.nan and item is np.nan: |
| 212 | + item = libmissing.NA |
212 | 213 | if not isinstance(item, str) and item is not libmissing.NA:
|
213 | 214 | raise TypeError("Scalar must be NA or str")
|
214 | 215 | return super().insert(loc, item)
|
215 | 216 |
|
216 |
| - @classmethod |
217 |
| - def _result_converter(cls, values, na=None): |
| 217 | + def _result_converter(self, values, na=None): |
| 218 | + if self.dtype.na_value is np.nan: |
| 219 | + if not isna(na): |
| 220 | + values = values.fill_null(bool(na)) |
| 221 | + return ArrowExtensionArray(values).to_numpy(na_value=np.nan) |
218 | 222 | return BooleanDtype().__from_arrow__(values)
|
219 | 223 |
|
220 | 224 | def _maybe_convert_setitem_value(self, value):
|
@@ -494,11 +498,30 @@ def _str_get_dummies(self, sep: str = "|"):
|
494 | 498 | return dummies.astype(np.int64, copy=False), labels
|
495 | 499 |
|
496 | 500 | def _convert_int_dtype(self, result):
|
| 501 | + if self.dtype.na_value is np.nan: |
| 502 | + if isinstance(result, pa.Array): |
| 503 | + result = result.to_numpy(zero_copy_only=False) |
| 504 | + else: |
| 505 | + result = result.to_numpy() |
| 506 | + if result.dtype == np.int32: |
| 507 | + result = result.astype(np.int64) |
| 508 | + return result |
| 509 | + |
497 | 510 | return Int64Dtype().__from_arrow__(result)
|
498 | 511 |
|
499 | 512 | def _reduce(
|
500 | 513 | self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
|
501 | 514 | ):
|
| 515 | + if self.dtype.na_value is np.nan and name in ["any", "all"]: |
| 516 | + if not skipna: |
| 517 | + nas = pc.is_null(self._pa_array) |
| 518 | + arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, "")) |
| 519 | + else: |
| 520 | + arr = pc.not_equal(self._pa_array, "") |
| 521 | + return ArrowExtensionArray(arr)._reduce( |
| 522 | + name, skipna=skipna, keepdims=keepdims, **kwargs |
| 523 | + ) |
| 524 | + |
502 | 525 | result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
|
503 | 526 | if name in ("argmin", "argmax") and isinstance(result, pa.Array):
|
504 | 527 | return self._convert_int_dtype(result)
|
@@ -529,67 +552,31 @@ def _rank(
|
529 | 552 | )
|
530 | 553 | )
|
531 | 554 |
|
532 |
| - |
533 |
| -class ArrowStringArrayNumpySemantics(ArrowStringArray): |
534 |
| - _storage = "pyarrow" |
535 |
| - _na_value = np.nan |
536 |
| - |
537 |
| - @classmethod |
538 |
| - def _result_converter(cls, values, na=None): |
539 |
| - if not isna(na): |
540 |
| - values = values.fill_null(bool(na)) |
541 |
| - return ArrowExtensionArray(values).to_numpy(na_value=np.nan) |
542 |
| - |
543 |
| - def __getattribute__(self, item): |
544 |
| - # ArrowStringArray and we both inherit from ArrowExtensionArray, which |
545 |
| - # creates inheritance problems (Diamond inheritance) |
546 |
| - if item in ArrowStringArrayMixin.__dict__ and item not in ( |
547 |
| - "_pa_array", |
548 |
| - "__dict__", |
549 |
| - ): |
550 |
| - return partial(getattr(ArrowStringArrayMixin, item), self) |
551 |
| - return super().__getattribute__(item) |
552 |
| - |
553 |
| - def _convert_int_dtype(self, result): |
554 |
| - if isinstance(result, pa.Array): |
555 |
| - result = result.to_numpy(zero_copy_only=False) |
556 |
| - else: |
557 |
| - result = result.to_numpy() |
558 |
| - if result.dtype == np.int32: |
559 |
| - result = result.astype(np.int64) |
| 555 | + def value_counts(self, dropna: bool = True) -> Series: |
| 556 | + result = super().value_counts(dropna=dropna) |
| 557 | + if self.dtype.na_value is np.nan: |
| 558 | + res_values = result._values.to_numpy() |
| 559 | + return result._constructor( |
| 560 | + res_values, index=result.index, name=result.name, copy=False |
| 561 | + ) |
560 | 562 | return result
|
561 | 563 |
|
562 | 564 | def _cmp_method(self, other, op):
|
563 | 565 | result = super()._cmp_method(other, op)
|
564 |
| - if op == operator.ne: |
565 |
| - return result.to_numpy(np.bool_, na_value=True) |
566 |
| - else: |
567 |
| - return result.to_numpy(np.bool_, na_value=False) |
568 |
| - |
569 |
| - def value_counts(self, dropna: bool = True) -> Series: |
570 |
| - from pandas import Series |
571 |
| - |
572 |
| - result = super().value_counts(dropna) |
573 |
| - return Series( |
574 |
| - result._values.to_numpy(), index=result.index, name=result.name, copy=False |
575 |
| - ) |
576 |
| - |
577 |
| - def _reduce( |
578 |
| - self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs |
579 |
| - ): |
580 |
| - if name in ["any", "all"]: |
581 |
| - if not skipna and name == "all": |
582 |
| - nas = pc.invert(pc.is_null(self._pa_array)) |
583 |
| - arr = pc.and_kleene(nas, pc.not_equal(self._pa_array, "")) |
| 566 | + if self.dtype.na_value is np.nan: |
| 567 | + if op == operator.ne: |
| 568 | + return result.to_numpy(np.bool_, na_value=True) |
584 | 569 | else:
|
585 |
| - arr = pc.not_equal(self._pa_array, "") |
586 |
| - return ArrowExtensionArray(arr)._reduce( |
587 |
| - name, skipna=skipna, keepdims=keepdims, **kwargs |
588 |
| - ) |
589 |
| - else: |
590 |
| - return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs) |
| 570 | + return result.to_numpy(np.bool_, na_value=False) |
| 571 | + return result |
591 | 572 |
|
592 |
| - def insert(self, loc: int, item) -> ArrowStringArrayNumpySemantics: |
593 |
| - if item is np.nan: |
594 |
| - item = libmissing.NA |
595 |
| - return super().insert(loc, item) # type: ignore[return-value] |
| 573 | + |
| 574 | +class ArrowStringArrayNumpySemantics(ArrowStringArray): |
| 575 | + _na_value = np.nan |
| 576 | + _str_get = ArrowStringArrayMixin._str_get |
| 577 | + _str_removesuffix = ArrowStringArrayMixin._str_removesuffix |
| 578 | + _str_capitalize = ArrowStringArrayMixin._str_capitalize |
| 579 | + _str_pad = ArrowStringArrayMixin._str_pad |
| 580 | + _str_title = ArrowStringArrayMixin._str_title |
| 581 | + _str_swapcase = ArrowStringArrayMixin._str_swapcase |
| 582 | + _str_slice_replace = ArrowStringArrayMixin._str_slice_replace |
0 commit comments