@@ -315,6 +315,8 @@ class BaseStringArray(ExtensionArray):
315
315
Mixin class for StringArray, ArrowStringArray.
316
316
"""
317
317
318
+ dtype : StringDtype
319
+
318
320
@doc (ExtensionArray .tolist )
319
321
def tolist (self ):
320
322
if self .ndim > 1 :
@@ -328,6 +330,37 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
328
330
raise ValueError
329
331
return cls ._from_sequence (scalars , dtype = dtype )
330
332
333
+ def _str_map_str_or_object (
334
+ self ,
335
+ dtype ,
336
+ na_value ,
337
+ arr : np .ndarray ,
338
+ f ,
339
+ mask : npt .NDArray [np .bool_ ],
340
+ convert : bool ,
341
+ ):
342
+ # _str_map helper for case where dtype is either string dtype or object
343
+ if is_string_dtype (dtype ) and not is_object_dtype (dtype ):
344
+ # i.e. StringDtype
345
+ result = lib .map_infer_mask (
346
+ arr , f , mask .view ("uint8" ), convert = False , na_value = na_value
347
+ )
348
+ if self .dtype .storage == "pyarrow" :
349
+ import pyarrow as pa
350
+
351
+ result = pa .array (
352
+ result , mask = mask , type = pa .large_string (), from_pandas = True
353
+ )
354
+ # error: Too many arguments for "BaseStringArray"
355
+ return type (self )(result ) # type: ignore[call-arg]
356
+
357
+ else :
358
+ # This is when the result type is object. We reach this when
359
+ # -> We know the result type is truly object (e.g. .encode returns bytes
360
+ # or .findall returns a list).
361
+ # -> We don't know the result type. E.g. `.get` can return anything.
362
+ return lib .map_infer_mask (arr , f , mask .view ("uint8" ))
363
+
331
364
332
365
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
333
366
# incompatible with definition in base class "ExtensionArray"
@@ -682,9 +715,53 @@ def _cmp_method(self, other, op):
682
715
# base class "NumpyExtensionArray" defined the type as "float")
683
716
_str_na_value = libmissing .NA # type: ignore[assignment]
684
717
718
+ def _str_map_nan_semantics (
719
+ self , f , na_value = None , dtype : Dtype | None = None , convert : bool = True
720
+ ):
721
+ if dtype is None :
722
+ dtype = self .dtype
723
+ if na_value is None :
724
+ na_value = self .dtype .na_value
725
+
726
+ mask = isna (self )
727
+ arr = np .asarray (self )
728
+ convert = convert and not np .all (mask )
729
+
730
+ if is_integer_dtype (dtype ) or is_bool_dtype (dtype ):
731
+ na_value_is_na = isna (na_value )
732
+ if na_value_is_na :
733
+ if is_integer_dtype (dtype ):
734
+ na_value = 0
735
+ else :
736
+ na_value = True
737
+
738
+ result = lib .map_infer_mask (
739
+ arr ,
740
+ f ,
741
+ mask .view ("uint8" ),
742
+ convert = False ,
743
+ na_value = na_value ,
744
+ dtype = np .dtype (cast (type , dtype )),
745
+ )
746
+ if na_value_is_na and mask .any ():
747
+ if is_integer_dtype (dtype ):
748
+ result = result .astype ("float64" )
749
+ else :
750
+ result = result .astype ("object" )
751
+ result [mask ] = np .nan
752
+ return result
753
+
754
+ else :
755
+ return self ._str_map_str_or_object (dtype , na_value , arr , f , mask , convert )
756
+
685
757
def _str_map (
686
758
self , f , na_value = None , dtype : Dtype | None = None , convert : bool = True
687
759
):
760
+ if self .dtype .na_value is np .nan :
761
+ return self ._str_map_nan_semantics (
762
+ f , na_value = na_value , dtype = dtype , convert = convert
763
+ )
764
+
688
765
from pandas .arrays import BooleanArray
689
766
690
767
if dtype is None :
@@ -724,18 +801,8 @@ def _str_map(
724
801
725
802
return constructor (result , mask )
726
803
727
- elif is_string_dtype (dtype ) and not is_object_dtype (dtype ):
728
- # i.e. StringDtype
729
- result = lib .map_infer_mask (
730
- arr , f , mask .view ("uint8" ), convert = False , na_value = na_value
731
- )
732
- return StringArray (result )
733
804
else :
734
- # This is when the result type is object. We reach this when
735
- # -> We know the result type is truly object (e.g. .encode returns bytes
736
- # or .findall returns a list).
737
- # -> We don't know the result type. E.g. `.get` can return anything.
738
- return lib .map_infer_mask (arr , f , mask .view ("uint8" ))
805
+ return self ._str_map_str_or_object (dtype , na_value , arr , f , mask , convert )
739
806
740
807
741
808
class StringArrayNumpySemantics (StringArray ):
@@ -802,52 +869,3 @@ def value_counts(self, dropna: bool = True) -> Series:
802
869
# ------------------------------------------------------------------------
803
870
# String methods interface
804
871
_str_na_value = np .nan
805
-
806
- def _str_map (
807
- self , f , na_value = None , dtype : Dtype | None = None , convert : bool = True
808
- ):
809
- if dtype is None :
810
- dtype = self .dtype
811
- if na_value is None :
812
- na_value = self .dtype .na_value
813
-
814
- mask = isna (self )
815
- arr = np .asarray (self )
816
- convert = convert and not np .all (mask )
817
-
818
- if is_integer_dtype (dtype ) or is_bool_dtype (dtype ):
819
- na_value_is_na = isna (na_value )
820
- if na_value_is_na :
821
- if is_integer_dtype (dtype ):
822
- na_value = 0
823
- else :
824
- na_value = True
825
-
826
- result = lib .map_infer_mask (
827
- arr ,
828
- f ,
829
- mask .view ("uint8" ),
830
- convert = False ,
831
- na_value = na_value ,
832
- dtype = np .dtype (cast (type , dtype )),
833
- )
834
- if na_value_is_na and mask .any ():
835
- if is_integer_dtype (dtype ):
836
- result = result .astype ("float64" )
837
- else :
838
- result = result .astype ("object" )
839
- result [mask ] = np .nan
840
- return result
841
-
842
- elif is_string_dtype (dtype ) and not is_object_dtype (dtype ):
843
- # i.e. StringDtype
844
- result = lib .map_infer_mask (
845
- arr , f , mask .view ("uint8" ), convert = False , na_value = na_value
846
- )
847
- return type (self )(result )
848
- else :
849
- # This is when the result type is object. We reach this when
850
- # -> We know the result type is truly object (e.g. .encode returns bytes
851
- # or .findall returns a list).
852
- # -> We don't know the result type. E.g. `.get` can return anything.
853
- return lib .map_infer_mask (arr , f , mask .view ("uint8" ))
0 commit comments