Skip to content

Commit 3a7d82d

Browse files
authored
REF (string dtype): de-duplicate _str_map (2) (#59451)
* REF (string): de-duplicate _str_map (2) * mypy fixup
1 parent 27c326a commit 3a7d82d

File tree

2 files changed

+92
-143
lines changed

2 files changed

+92
-143
lines changed

pandas/core/arrays/string_.py

+90-89
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,57 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
346346
raise ValueError
347347
return cls._from_sequence(scalars, dtype=dtype)
348348

349+
def _str_map(
350+
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
351+
):
352+
if self.dtype.na_value is np.nan:
353+
return self._str_map_nan_semantics(
354+
f, na_value=na_value, dtype=dtype, convert=convert
355+
)
356+
357+
from pandas.arrays import BooleanArray
358+
359+
if dtype is None:
360+
dtype = self.dtype
361+
if na_value is None:
362+
na_value = self.dtype.na_value
363+
364+
mask = isna(self)
365+
arr = np.asarray(self)
366+
367+
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
368+
constructor: type[IntegerArray | BooleanArray]
369+
if is_integer_dtype(dtype):
370+
constructor = IntegerArray
371+
else:
372+
constructor = BooleanArray
373+
374+
na_value_is_na = isna(na_value)
375+
if na_value_is_na:
376+
na_value = 1
377+
elif dtype == np.dtype("bool"):
378+
# GH#55736
379+
na_value = bool(na_value)
380+
result = lib.map_infer_mask(
381+
arr,
382+
f,
383+
mask.view("uint8"),
384+
convert=False,
385+
na_value=na_value,
386+
# error: Argument 1 to "dtype" has incompatible type
387+
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
388+
# "Type[object]"
389+
dtype=np.dtype(cast(type, dtype)),
390+
)
391+
392+
if not na_value_is_na:
393+
mask[:] = False
394+
395+
return constructor(result, mask)
396+
397+
else:
398+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
399+
349400
def _str_map_str_or_object(
350401
self,
351402
dtype,
@@ -377,6 +428,45 @@ def _str_map_str_or_object(
377428
# -> We don't know the result type. E.g. `.get` can return anything.
378429
return lib.map_infer_mask(arr, f, mask.view("uint8"))
379430

431+
def _str_map_nan_semantics(
432+
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
433+
):
434+
if dtype is None:
435+
dtype = self.dtype
436+
if na_value is None:
437+
na_value = self.dtype.na_value
438+
439+
mask = isna(self)
440+
arr = np.asarray(self)
441+
convert = convert and not np.all(mask)
442+
443+
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
444+
na_value_is_na = isna(na_value)
445+
if na_value_is_na:
446+
if is_integer_dtype(dtype):
447+
na_value = 0
448+
else:
449+
na_value = True
450+
451+
result = lib.map_infer_mask(
452+
arr,
453+
f,
454+
mask.view("uint8"),
455+
convert=False,
456+
na_value=na_value,
457+
dtype=np.dtype(cast(type, dtype)),
458+
)
459+
if na_value_is_na and mask.any():
460+
if is_integer_dtype(dtype):
461+
result = result.astype("float64")
462+
else:
463+
result = result.astype("object")
464+
result[mask] = np.nan
465+
return result
466+
467+
else:
468+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
469+
380470

381471
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
382472
# incompatible with definition in base class "ExtensionArray"
@@ -742,95 +832,6 @@ def _cmp_method(self, other, op):
742832
# base class "NumpyExtensionArray" defined the type as "float")
743833
_str_na_value = libmissing.NA # type: ignore[assignment]
744834

745-
def _str_map_nan_semantics(
746-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
747-
):
748-
if dtype is None:
749-
dtype = self.dtype
750-
if na_value is None:
751-
na_value = self.dtype.na_value
752-
753-
mask = isna(self)
754-
arr = np.asarray(self)
755-
convert = convert and not np.all(mask)
756-
757-
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
758-
na_value_is_na = isna(na_value)
759-
if na_value_is_na:
760-
if is_integer_dtype(dtype):
761-
na_value = 0
762-
else:
763-
na_value = True
764-
765-
result = lib.map_infer_mask(
766-
arr,
767-
f,
768-
mask.view("uint8"),
769-
convert=False,
770-
na_value=na_value,
771-
dtype=np.dtype(cast(type, dtype)),
772-
)
773-
if na_value_is_na and mask.any():
774-
if is_integer_dtype(dtype):
775-
result = result.astype("float64")
776-
else:
777-
result = result.astype("object")
778-
result[mask] = np.nan
779-
return result
780-
781-
else:
782-
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
783-
784-
def _str_map(
785-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
786-
):
787-
if self.dtype.na_value is np.nan:
788-
return self._str_map_nan_semantics(
789-
f, na_value=na_value, dtype=dtype, convert=convert
790-
)
791-
792-
from pandas.arrays import BooleanArray
793-
794-
if dtype is None:
795-
dtype = StringDtype(storage="python")
796-
if na_value is None:
797-
na_value = self.dtype.na_value
798-
799-
mask = isna(self)
800-
arr = np.asarray(self)
801-
802-
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
803-
constructor: type[IntegerArray | BooleanArray]
804-
if is_integer_dtype(dtype):
805-
constructor = IntegerArray
806-
else:
807-
constructor = BooleanArray
808-
809-
na_value_is_na = isna(na_value)
810-
if na_value_is_na:
811-
na_value = 1
812-
elif dtype == np.dtype("bool"):
813-
na_value = bool(na_value)
814-
result = lib.map_infer_mask(
815-
arr,
816-
f,
817-
mask.view("uint8"),
818-
convert=False,
819-
na_value=na_value,
820-
# error: Argument 1 to "dtype" has incompatible type
821-
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
822-
# "Type[object]"
823-
dtype=np.dtype(cast(type, dtype)),
824-
)
825-
826-
if not na_value_is_na:
827-
mask[:] = False
828-
829-
return constructor(result, mask)
830-
831-
else:
832-
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
833-
834835

835836
class StringArrayNumpySemantics(StringArray):
836837
_storage = "python"

pandas/core/arrays/string_arrow.py

+2-54
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,8 @@ def astype(self, dtype, copy: bool = True):
279279
# base class "ObjectStringArrayMixin" defined the type as "float")
280280
_str_na_value = libmissing.NA # type: ignore[assignment]
281281

282+
_str_map = BaseStringArray._str_map
283+
282284
def _str_map_nan_semantics(
283285
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
284286
):
@@ -318,60 +320,6 @@ def _str_map_nan_semantics(
318320
else:
319321
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
320322

321-
def _str_map(
322-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
323-
):
324-
if self.dtype.na_value is np.nan:
325-
return self._str_map_nan_semantics(
326-
f, na_value=na_value, dtype=dtype, convert=convert
327-
)
328-
329-
# TODO: de-duplicate with StringArray method. This method is moreless copy and
330-
# paste.
331-
332-
from pandas.arrays import (
333-
BooleanArray,
334-
IntegerArray,
335-
)
336-
337-
if dtype is None:
338-
dtype = self.dtype
339-
if na_value is None:
340-
na_value = self.dtype.na_value
341-
342-
mask = isna(self)
343-
arr = np.asarray(self)
344-
345-
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
346-
constructor: type[IntegerArray | BooleanArray]
347-
if is_integer_dtype(dtype):
348-
constructor = IntegerArray
349-
else:
350-
constructor = BooleanArray
351-
352-
na_value_is_na = isna(na_value)
353-
if na_value_is_na:
354-
na_value = 1
355-
result = lib.map_infer_mask(
356-
arr,
357-
f,
358-
mask.view("uint8"),
359-
convert=False,
360-
na_value=na_value,
361-
# error: Argument 1 to "dtype" has incompatible type
362-
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
363-
# "Type[object]"
364-
dtype=np.dtype(cast(type, dtype)),
365-
)
366-
367-
if not na_value_is_na:
368-
mask[:] = False
369-
370-
return constructor(result, mask)
371-
372-
else:
373-
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
374-
375323
def _str_contains(
376324
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
377325
):

0 commit comments

Comments
 (0)