Skip to content

Commit 6b8c9f9

Browse files
jbrockmendel authored and WillAyd committed
REF (string dtype): de-duplicate _str_map (2) (pandas-dev#59451)
* REF (string): de-duplicate _str_map (2)
* mypy fixup
1 parent b737098 commit 6b8c9f9

File tree

2 files changed

+92
-143
lines changed

2 files changed

+92
-143
lines changed

pandas/core/arrays/string_.py

+90 −89
Original file line number | Diff line number | Diff line change
@@ -342,6 +342,57 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
342342
raise ValueError
343343
return cls._from_sequence(scalars, dtype=dtype)
344344

345+
def _str_map(
346+
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
347+
):
348+
if self.dtype.na_value is np.nan:
349+
return self._str_map_nan_semantics(
350+
f, na_value=na_value, dtype=dtype, convert=convert
351+
)
352+
353+
from pandas.arrays import BooleanArray
354+
355+
if dtype is None:
356+
dtype = self.dtype
357+
if na_value is None:
358+
na_value = self.dtype.na_value
359+
360+
mask = isna(self)
361+
arr = np.asarray(self)
362+
363+
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
364+
constructor: type[IntegerArray | BooleanArray]
365+
if is_integer_dtype(dtype):
366+
constructor = IntegerArray
367+
else:
368+
constructor = BooleanArray
369+
370+
na_value_is_na = isna(na_value)
371+
if na_value_is_na:
372+
na_value = 1
373+
elif dtype == np.dtype("bool"):
374+
# GH#55736
375+
na_value = bool(na_value)
376+
result = lib.map_infer_mask(
377+
arr,
378+
f,
379+
mask.view("uint8"),
380+
convert=False,
381+
na_value=na_value,
382+
# error: Argument 1 to "dtype" has incompatible type
383+
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
384+
# "Type[object]"
385+
dtype=np.dtype(cast(type, dtype)),
386+
)
387+
388+
if not na_value_is_na:
389+
mask[:] = False
390+
391+
return constructor(result, mask)
392+
393+
else:
394+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
395+
345396
def _str_map_str_or_object(
346397
self,
347398
dtype,
@@ -373,6 +424,45 @@ def _str_map_str_or_object(
373424
# -> We don't know the result type. E.g. `.get` can return anything.
374425
return lib.map_infer_mask(arr, f, mask.view("uint8"))
375426

427+
def _str_map_nan_semantics(
428+
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
429+
):
430+
if dtype is None:
431+
dtype = self.dtype
432+
if na_value is None:
433+
na_value = self.dtype.na_value
434+
435+
mask = isna(self)
436+
arr = np.asarray(self)
437+
convert = convert and not np.all(mask)
438+
439+
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
440+
na_value_is_na = isna(na_value)
441+
if na_value_is_na:
442+
if is_integer_dtype(dtype):
443+
na_value = 0
444+
else:
445+
na_value = True
446+
447+
result = lib.map_infer_mask(
448+
arr,
449+
f,
450+
mask.view("uint8"),
451+
convert=False,
452+
na_value=na_value,
453+
dtype=np.dtype(cast(type, dtype)),
454+
)
455+
if na_value_is_na and mask.any():
456+
if is_integer_dtype(dtype):
457+
result = result.astype("float64")
458+
else:
459+
result = result.astype("object")
460+
result[mask] = np.nan
461+
return result
462+
463+
else:
464+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
465+
376466

377467
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
378468
# incompatible with definition in base class "ExtensionArray"
@@ -727,95 +817,6 @@ def _cmp_method(self, other, op):
727817
# base class "NumpyExtensionArray" defined the type as "float")
728818
_str_na_value = libmissing.NA # type: ignore[assignment]
729819

730-
def _str_map_nan_semantics(
731-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
732-
):
733-
if dtype is None:
734-
dtype = self.dtype
735-
if na_value is None:
736-
na_value = self.dtype.na_value
737-
738-
mask = isna(self)
739-
arr = np.asarray(self)
740-
convert = convert and not np.all(mask)
741-
742-
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
743-
na_value_is_na = isna(na_value)
744-
if na_value_is_na:
745-
if is_integer_dtype(dtype):
746-
na_value = 0
747-
else:
748-
na_value = True
749-
750-
result = lib.map_infer_mask(
751-
arr,
752-
f,
753-
mask.view("uint8"),
754-
convert=False,
755-
na_value=na_value,
756-
dtype=np.dtype(cast(type, dtype)),
757-
)
758-
if na_value_is_na and mask.any():
759-
if is_integer_dtype(dtype):
760-
result = result.astype("float64")
761-
else:
762-
result = result.astype("object")
763-
result[mask] = np.nan
764-
return result
765-
766-
else:
767-
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
768-
769-
def _str_map(
770-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
771-
):
772-
if self.dtype.na_value is np.nan:
773-
return self._str_map_nan_semantics(
774-
f, na_value=na_value, dtype=dtype, convert=convert
775-
)
776-
777-
from pandas.arrays import BooleanArray
778-
779-
if dtype is None:
780-
dtype = StringDtype(storage="python")
781-
if na_value is None:
782-
na_value = self.dtype.na_value
783-
784-
mask = isna(self)
785-
arr = np.asarray(self)
786-
787-
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
788-
constructor: type[IntegerArray | BooleanArray]
789-
if is_integer_dtype(dtype):
790-
constructor = IntegerArray
791-
else:
792-
constructor = BooleanArray
793-
794-
na_value_is_na = isna(na_value)
795-
if na_value_is_na:
796-
na_value = 1
797-
elif dtype == np.dtype("bool"):
798-
na_value = bool(na_value)
799-
result = lib.map_infer_mask(
800-
arr,
801-
f,
802-
mask.view("uint8"),
803-
convert=False,
804-
na_value=na_value,
805-
# error: Argument 1 to "dtype" has incompatible type
806-
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
807-
# "Type[object]"
808-
dtype=np.dtype(dtype), # type: ignore[arg-type]
809-
)
810-
811-
if not na_value_is_na:
812-
mask[:] = False
813-
814-
return constructor(result, mask)
815-
816-
else:
817-
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
818-
819820

820821
class StringArrayNumpySemantics(StringArray):
821822
_storage = "python"

pandas/core/arrays/string_arrow.py

+2 −54
Original file line number | Diff line number | Diff line change
@@ -283,6 +283,8 @@ def _data(self):
283283
# base class "ObjectStringArrayMixin" defined the type as "float")
284284
_str_na_value = libmissing.NA # type: ignore[assignment]
285285

286+
_str_map = BaseStringArray._str_map
287+
286288
def _str_map_nan_semantics(
287289
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
288290
):
@@ -322,60 +324,6 @@ def _str_map_nan_semantics(
322324
else:
323325
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
324326

325-
def _str_map(
326-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
327-
):
328-
if self.dtype.na_value is np.nan:
329-
return self._str_map_nan_semantics(
330-
f, na_value=na_value, dtype=dtype, convert=convert
331-
)
332-
333-
# TODO: de-duplicate with StringArray method. This method is moreless copy and
334-
# paste.
335-
336-
from pandas.arrays import (
337-
BooleanArray,
338-
IntegerArray,
339-
)
340-
341-
if dtype is None:
342-
dtype = self.dtype
343-
if na_value is None:
344-
na_value = self.dtype.na_value
345-
346-
mask = isna(self)
347-
arr = np.asarray(self)
348-
349-
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
350-
constructor: type[IntegerArray | BooleanArray]
351-
if is_integer_dtype(dtype):
352-
constructor = IntegerArray
353-
else:
354-
constructor = BooleanArray
355-
356-
na_value_is_na = isna(na_value)
357-
if na_value_is_na:
358-
na_value = 1
359-
result = lib.map_infer_mask(
360-
arr,
361-
f,
362-
mask.view("uint8"),
363-
convert=False,
364-
na_value=na_value,
365-
# error: Argument 1 to "dtype" has incompatible type
366-
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
367-
# "Type[object]"
368-
dtype=np.dtype(dtype), # type: ignore[arg-type]
369-
)
370-
371-
if not na_value_is_na:
372-
mask[:] = False
373-
374-
return constructor(result, mask)
375-
376-
else:
377-
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
378-
379327
def _str_contains(
380328
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
381329
):

0 commit comments

Comments
 (0)