Skip to content

Commit 80e6a51

Browse files
committed
REF (string dtype): de-duplicate _str_map methods (pandas-dev#59443)
* REF: de-duplicate _str_map methods * mypy fixup
1 parent 4fb4478 commit 80e6a51

File tree

2 files changed

+124
-131
lines changed

2 files changed

+124
-131
lines changed

pandas/core/arrays/string_.py

+78-60
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,8 @@ class BaseStringArray(ExtensionArray):
315315
Mixin class for StringArray, ArrowStringArray.
316316
"""
317317

318+
dtype: StringDtype
319+
318320
@doc(ExtensionArray.tolist)
319321
def tolist(self):
320322
if self.ndim > 1:
@@ -328,6 +330,37 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
328330
raise ValueError
329331
return cls._from_sequence(scalars, dtype=dtype)
330332

333+
def _str_map_str_or_object(
334+
self,
335+
dtype,
336+
na_value,
337+
arr: np.ndarray,
338+
f,
339+
mask: npt.NDArray[np.bool_],
340+
convert: bool,
341+
):
342+
# _str_map helper for case where dtype is either string dtype or object
343+
if is_string_dtype(dtype) and not is_object_dtype(dtype):
344+
# i.e. StringDtype
345+
result = lib.map_infer_mask(
346+
arr, f, mask.view("uint8"), convert=False, na_value=na_value
347+
)
348+
if self.dtype.storage == "pyarrow":
349+
import pyarrow as pa
350+
351+
result = pa.array(
352+
result, mask=mask, type=pa.large_string(), from_pandas=True
353+
)
354+
# error: Too many arguments for "BaseStringArray"
355+
return type(self)(result) # type: ignore[call-arg]
356+
357+
else:
358+
# This is when the result type is object. We reach this when
359+
# -> We know the result type is truly object (e.g. .encode returns bytes
360+
# or .findall returns a list).
361+
# -> We don't know the result type. E.g. `.get` can return anything.
362+
return lib.map_infer_mask(arr, f, mask.view("uint8"))
363+
331364

332365
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
333366
# incompatible with definition in base class "ExtensionArray"
@@ -682,9 +715,53 @@ def _cmp_method(self, other, op):
682715
# base class "NumpyExtensionArray" defined the type as "float")
683716
_str_na_value = libmissing.NA # type: ignore[assignment]
684717

718+
def _str_map_nan_semantics(
719+
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
720+
):
721+
if dtype is None:
722+
dtype = self.dtype
723+
if na_value is None:
724+
na_value = self.dtype.na_value
725+
726+
mask = isna(self)
727+
arr = np.asarray(self)
728+
convert = convert and not np.all(mask)
729+
730+
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
731+
na_value_is_na = isna(na_value)
732+
if na_value_is_na:
733+
if is_integer_dtype(dtype):
734+
na_value = 0
735+
else:
736+
na_value = True
737+
738+
result = lib.map_infer_mask(
739+
arr,
740+
f,
741+
mask.view("uint8"),
742+
convert=False,
743+
na_value=na_value,
744+
dtype=np.dtype(cast(type, dtype)),
745+
)
746+
if na_value_is_na and mask.any():
747+
if is_integer_dtype(dtype):
748+
result = result.astype("float64")
749+
else:
750+
result = result.astype("object")
751+
result[mask] = np.nan
752+
return result
753+
754+
else:
755+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
756+
685757
def _str_map(
686758
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
687759
):
760+
if self.dtype.na_value is np.nan:
761+
return self._str_map_nan_semantics(
762+
f, na_value=na_value, dtype=dtype, convert=convert
763+
)
764+
688765
from pandas.arrays import BooleanArray
689766

690767
if dtype is None:
@@ -724,18 +801,8 @@ def _str_map(
724801

725802
return constructor(result, mask)
726803

727-
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
728-
# i.e. StringDtype
729-
result = lib.map_infer_mask(
730-
arr, f, mask.view("uint8"), convert=False, na_value=na_value
731-
)
732-
return StringArray(result)
733804
else:
734-
# This is when the result type is object. We reach this when
735-
# -> We know the result type is truly object (e.g. .encode returns bytes
736-
# or .findall returns a list).
737-
# -> We don't know the result type. E.g. `.get` can return anything.
738-
return lib.map_infer_mask(arr, f, mask.view("uint8"))
805+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
739806

740807

741808
class StringArrayNumpySemantics(StringArray):
@@ -802,52 +869,3 @@ def value_counts(self, dropna: bool = True) -> Series:
802869
# ------------------------------------------------------------------------
803870
# String methods interface
804871
_str_na_value = np.nan
805-
806-
def _str_map(
807-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
808-
):
809-
if dtype is None:
810-
dtype = self.dtype
811-
if na_value is None:
812-
na_value = self.dtype.na_value
813-
814-
mask = isna(self)
815-
arr = np.asarray(self)
816-
convert = convert and not np.all(mask)
817-
818-
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
819-
na_value_is_na = isna(na_value)
820-
if na_value_is_na:
821-
if is_integer_dtype(dtype):
822-
na_value = 0
823-
else:
824-
na_value = True
825-
826-
result = lib.map_infer_mask(
827-
arr,
828-
f,
829-
mask.view("uint8"),
830-
convert=False,
831-
na_value=na_value,
832-
dtype=np.dtype(cast(type, dtype)),
833-
)
834-
if na_value_is_na and mask.any():
835-
if is_integer_dtype(dtype):
836-
result = result.astype("float64")
837-
else:
838-
result = result.astype("object")
839-
result[mask] = np.nan
840-
return result
841-
842-
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
843-
# i.e. StringDtype
844-
result = lib.map_infer_mask(
845-
arr, f, mask.view("uint8"), convert=False, na_value=na_value
846-
)
847-
return type(self)(result)
848-
else:
849-
# This is when the result type is object. We reach this when
850-
# -> We know the result type is truly object (e.g. .encode returns bytes
851-
# or .findall returns a list).
852-
# -> We don't know the result type. E.g. `.get` can return anything.
853-
return lib.map_infer_mask(arr, f, mask.view("uint8"))

pandas/core/arrays/string_arrow.py

+46-71
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
TYPE_CHECKING,
88
Callable,
99
Union,
10+
cast,
1011
)
1112
import warnings
1213

@@ -25,9 +26,7 @@
2526
from pandas.core.dtypes.common import (
2627
is_bool_dtype,
2728
is_integer_dtype,
28-
is_object_dtype,
2929
is_scalar,
30-
is_string_dtype,
3130
pandas_dtype,
3231
)
3332
from pandas.core.dtypes.missing import isna
@@ -284,9 +283,53 @@ def _data(self):
284283
# base class "ObjectStringArrayMixin" defined the type as "float")
285284
_str_na_value = libmissing.NA # type: ignore[assignment]
286285

286+
def _str_map_nan_semantics(
287+
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
288+
):
289+
if dtype is None:
290+
dtype = self.dtype
291+
if na_value is None:
292+
na_value = self.dtype.na_value
293+
294+
mask = isna(self)
295+
arr = np.asarray(self)
296+
297+
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
298+
if is_integer_dtype(dtype):
299+
na_value = np.nan
300+
else:
301+
na_value = False
302+
303+
dtype = np.dtype(cast(type, dtype))
304+
if mask.any():
305+
# numpy int/bool dtypes cannot hold NaNs so we must convert to
306+
# float64 for int (to match maybe_convert_objects) or
307+
# object for bool (again to match maybe_convert_objects)
308+
if is_integer_dtype(dtype):
309+
dtype = np.dtype("float64")
310+
else:
311+
dtype = np.dtype(object)
312+
result = lib.map_infer_mask(
313+
arr,
314+
f,
315+
mask.view("uint8"),
316+
convert=False,
317+
na_value=na_value,
318+
dtype=dtype,
319+
)
320+
return result
321+
322+
else:
323+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
324+
287325
def _str_map(
288326
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
289327
):
328+
if self.dtype.na_value is np.nan:
329+
return self._str_map_nan_semantics(
330+
f, na_value=na_value, dtype=dtype, convert=convert
331+
)
332+
290333
# TODO: de-duplicate with StringArray method. This method is moreless copy and
291334
# paste.
292335

@@ -330,21 +373,8 @@ def _str_map(
330373

331374
return constructor(result, mask)
332375

333-
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
334-
# i.e. StringDtype
335-
result = lib.map_infer_mask(
336-
arr, f, mask.view("uint8"), convert=False, na_value=na_value
337-
)
338-
result = pa.array(
339-
result, mask=mask, type=pa.large_string(), from_pandas=True
340-
)
341-
return type(self)(result)
342376
else:
343-
# This is when the result type is object. We reach this when
344-
# -> We know the result type is truly object (e.g. .encode returns bytes
345-
# or .findall returns a list).
346-
# -> We don't know the result type. E.g. `.get` can return anything.
347-
return lib.map_infer_mask(arr, f, mask.view("uint8"))
377+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
348378

349379
def _str_contains(
350380
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
@@ -615,61 +645,6 @@ def __getattribute__(self, item):
615645
return partial(getattr(ArrowStringArrayMixin, item), self)
616646
return super().__getattribute__(item)
617647

618-
def _str_map(
619-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
620-
):
621-
if dtype is None:
622-
dtype = self.dtype
623-
if na_value is None:
624-
na_value = self.dtype.na_value
625-
626-
mask = isna(self)
627-
arr = np.asarray(self)
628-
629-
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
630-
if is_integer_dtype(dtype):
631-
na_value = np.nan
632-
else:
633-
na_value = False
634-
try:
635-
result = lib.map_infer_mask(
636-
arr,
637-
f,
638-
mask.view("uint8"),
639-
convert=False,
640-
na_value=na_value,
641-
dtype=np.dtype(dtype), # type: ignore[arg-type]
642-
)
643-
return result
644-
645-
except ValueError:
646-
result = lib.map_infer_mask(
647-
arr,
648-
f,
649-
mask.view("uint8"),
650-
convert=False,
651-
na_value=na_value,
652-
)
653-
if convert and result.dtype == object:
654-
result = lib.maybe_convert_objects(result)
655-
return result
656-
657-
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
658-
# i.e. StringDtype
659-
result = lib.map_infer_mask(
660-
arr, f, mask.view("uint8"), convert=False, na_value=na_value
661-
)
662-
result = pa.array(
663-
result, mask=mask, type=pa.large_string(), from_pandas=True
664-
)
665-
return type(self)(result)
666-
else:
667-
# This is when the result type is object. We reach this when
668-
# -> We know the result type is truly object (e.g. .encode returns bytes
669-
# or .findall returns a list).
670-
# -> We don't know the result type. E.g. `.get` can return anything.
671-
return lib.map_infer_mask(arr, f, mask.view("uint8"))
672-
673648
def _convert_int_dtype(self, result):
674649
if isinstance(result, pa.Array):
675650
result = result.to_numpy(zero_copy_only=False)

0 commit comments

Comments
 (0)