|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -from functools import partial |
4 | 3 | import operator
|
5 | 4 | import re
|
6 | 5 | from typing import (
|
@@ -216,12 +215,17 @@ def dtype(self) -> StringDtype: # type: ignore[override]
|
216 | 215 | return self._dtype
|
217 | 216 |
|
218 | 217 | def insert(self, loc: int, item) -> ArrowStringArray:
|
| 218 | + if self.dtype.na_value is np.nan and item is np.nan: |
| 219 | + item = libmissing.NA |
219 | 220 | if not isinstance(item, str) and item is not libmissing.NA:
|
220 | 221 | raise TypeError("Scalar must be NA or str")
|
221 | 222 | return super().insert(loc, item)
|
222 | 223 |
|
223 |
| - @classmethod |
224 |
| - def _result_converter(cls, values, na=None): |
| 224 | + def _result_converter(self, values, na=None): |
| 225 | + if self.dtype.na_value is np.nan: |
| 226 | + if not isna(na): |
| 227 | + values = values.fill_null(bool(na)) |
| 228 | + return ArrowExtensionArray(values).to_numpy(na_value=np.nan) |
225 | 229 | return BooleanDtype().__from_arrow__(values)
|
226 | 230 |
|
227 | 231 | def _maybe_convert_setitem_value(self, value):
|
@@ -492,11 +496,30 @@ def _str_get_dummies(self, sep: str = "|"):
|
492 | 496 | return dummies.astype(np.int64, copy=False), labels
|
493 | 497 |
|
494 | 498 | def _convert_int_dtype(self, result):
|
| 499 | + if self.dtype.na_value is np.nan: |
| 500 | + if isinstance(result, pa.Array): |
| 501 | + result = result.to_numpy(zero_copy_only=False) |
| 502 | + else: |
| 503 | + result = result.to_numpy() |
| 504 | + if result.dtype == np.int32: |
| 505 | + result = result.astype(np.int64) |
| 506 | + return result |
| 507 | + |
495 | 508 | return Int64Dtype().__from_arrow__(result)
|
496 | 509 |
|
497 | 510 | def _reduce(
|
498 | 511 | self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
|
499 | 512 | ):
|
| 513 | + if self.dtype.na_value is np.nan and name in ["any", "all"]: |
| 514 | + if not skipna: |
| 515 | + nas = pc.is_null(self._pa_array) |
| 516 | + arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, "")) |
| 517 | + else: |
| 518 | + arr = pc.not_equal(self._pa_array, "") |
| 519 | + return ArrowExtensionArray(arr)._reduce( |
| 520 | + name, skipna=skipna, keepdims=keepdims, **kwargs |
| 521 | + ) |
| 522 | + |
500 | 523 | result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
|
501 | 524 | if name in ("argmin", "argmax") and isinstance(result, pa.Array):
|
502 | 525 | return self._convert_int_dtype(result)
|
@@ -527,67 +550,31 @@ def _rank(
|
527 | 550 | )
|
528 | 551 | )
|
529 | 552 |
|
530 |
| - |
531 |
| -class ArrowStringArrayNumpySemantics(ArrowStringArray): |
532 |
| - _storage = "pyarrow" |
533 |
| - _na_value = np.nan |
534 |
| - |
535 |
| - @classmethod |
536 |
| - def _result_converter(cls, values, na=None): |
537 |
| - if not isna(na): |
538 |
| - values = values.fill_null(bool(na)) |
539 |
| - return ArrowExtensionArray(values).to_numpy(na_value=np.nan) |
540 |
| - |
541 |
| - def __getattribute__(self, item): |
542 |
| - # ArrowStringArray and we both inherit from ArrowExtensionArray, which |
543 |
| - # creates inheritance problems (Diamond inheritance) |
544 |
| - if item in ArrowStringArrayMixin.__dict__ and item not in ( |
545 |
| - "_pa_array", |
546 |
| - "__dict__", |
547 |
| - ): |
548 |
| - return partial(getattr(ArrowStringArrayMixin, item), self) |
549 |
| - return super().__getattribute__(item) |
550 |
| - |
551 |
| - def _convert_int_dtype(self, result): |
552 |
| - if isinstance(result, pa.Array): |
553 |
| - result = result.to_numpy(zero_copy_only=False) |
554 |
| - else: |
555 |
| - result = result.to_numpy() |
556 |
| - if result.dtype == np.int32: |
557 |
| - result = result.astype(np.int64) |
| 553 | + def value_counts(self, dropna: bool = True) -> Series: |
| 554 | + result = super().value_counts(dropna=dropna) |
| 555 | + if self.dtype.na_value is np.nan: |
| 556 | + res_values = result._values.to_numpy() |
| 557 | + return result._constructor( |
| 558 | + res_values, index=result.index, name=result.name, copy=False |
| 559 | + ) |
558 | 560 | return result
|
559 | 561 |
|
560 | 562 | def _cmp_method(self, other, op):
|
561 | 563 | result = super()._cmp_method(other, op)
|
562 |
| - if op == operator.ne: |
563 |
| - return result.to_numpy(np.bool_, na_value=True) |
564 |
| - else: |
565 |
| - return result.to_numpy(np.bool_, na_value=False) |
566 |
| - |
567 |
| - def value_counts(self, dropna: bool = True) -> Series: |
568 |
| - from pandas import Series |
569 |
| - |
570 |
| - result = super().value_counts(dropna) |
571 |
| - return Series( |
572 |
| - result._values.to_numpy(), index=result.index, name=result.name, copy=False |
573 |
| - ) |
574 |
| - |
575 |
| - def _reduce( |
576 |
| - self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs |
577 |
| - ): |
578 |
| - if name in ["any", "all"]: |
579 |
| - if not skipna: |
580 |
| - nas = pc.is_null(self._pa_array) |
581 |
| - arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, "")) |
| 564 | + if self.dtype.na_value is np.nan: |
| 565 | + if op == operator.ne: |
| 566 | + return result.to_numpy(np.bool_, na_value=True) |
582 | 567 | else:
|
583 |
| - arr = pc.not_equal(self._pa_array, "") |
584 |
| - return ArrowExtensionArray(arr)._reduce( |
585 |
| - name, skipna=skipna, keepdims=keepdims, **kwargs |
586 |
| - ) |
587 |
| - else: |
588 |
| - return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs) |
| 568 | + return result.to_numpy(np.bool_, na_value=False) |
| 569 | + return result |
589 | 570 |
|
590 |
| - def insert(self, loc: int, item) -> ArrowStringArrayNumpySemantics: |
591 |
| - if item is np.nan: |
592 |
| - item = libmissing.NA |
593 |
| - return super().insert(loc, item) # type: ignore[return-value] |
| 571 | + |
| 572 | +class ArrowStringArrayNumpySemantics(ArrowStringArray): |
| 573 | + _na_value = np.nan |
| 574 | + _str_get = ArrowStringArrayMixin._str_get |
| 575 | + _str_removesuffix = ArrowStringArrayMixin._str_removesuffix |
| 576 | + _str_capitalize = ArrowStringArrayMixin._str_capitalize |
| 577 | + _str_pad = ArrowStringArrayMixin._str_pad |
| 578 | + _str_title = ArrowStringArrayMixin._str_title |
| 579 | + _str_swapcase = ArrowStringArrayMixin._str_swapcase |
| 580 | + _str_slice_replace = ArrowStringArrayMixin._str_slice_replace |
0 commit comments