Skip to content

REF (string dtype): de-duplicate _str_map methods #59443

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 78 additions & 60 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ class BaseStringArray(ExtensionArray):
Mixin class for StringArray, ArrowStringArray.
"""

dtype: StringDtype

@doc(ExtensionArray.tolist)
def tolist(self) -> list:
if self.ndim > 1:
Expand All @@ -332,6 +334,37 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
raise ValueError
return cls._from_sequence(scalars, dtype=dtype)

def _str_map_str_or_object(
self,
dtype,
na_value,
arr: np.ndarray,
f,
mask: npt.NDArray[np.bool_],
convert: bool,
):
# _str_map helper for case where dtype is either string dtype or object
if is_string_dtype(dtype) and not is_object_dtype(dtype):
# i.e. StringDtype
result = lib.map_infer_mask(
arr, f, mask.view("uint8"), convert=False, na_value=na_value
)
if self.dtype.storage == "pyarrow":
import pyarrow as pa

result = pa.array(
result, mask=mask, type=pa.large_string(), from_pandas=True
)
# error: Too many arguments for "BaseStringArray"
return type(self)(result) # type: ignore[call-arg]

else:
# This is when the result type is object. We reach this when
# -> We know the result type is truly object (e.g. .encode returns bytes
# or .findall returns a list).
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))


# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
# incompatible with definition in base class "ExtensionArray"
Expand Down Expand Up @@ -697,9 +730,53 @@ def _cmp_method(self, other, op):
# base class "NumpyExtensionArray" defined the type as "float")
_str_na_value = libmissing.NA # type: ignore[assignment]

def _str_map_nan_semantics(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)
convert = convert and not np.all(mask)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
na_value_is_na = isna(na_value)
if na_value_is_na:
if is_integer_dtype(dtype):
na_value = 0
else:
na_value = True

result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
dtype=np.dtype(cast(type, dtype)),
)
if na_value_is_na and mask.any():
if is_integer_dtype(dtype):
result = result.astype("float64")
else:
result = result.astype("object")
result[mask] = np.nan
return result

else:
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if self.dtype.na_value is np.nan:
return self._str_map_nan_semantics(
f, na_value=na_value, dtype=dtype, convert=convert
)

from pandas.arrays import BooleanArray

if dtype is None:
Expand Down Expand Up @@ -739,18 +816,8 @@ def _str_map(

return constructor(result, mask)

elif is_string_dtype(dtype) and not is_object_dtype(dtype):
# i.e. StringDtype
result = lib.map_infer_mask(
arr, f, mask.view("uint8"), convert=False, na_value=na_value
)
return StringArray(result)
else:
# This is when the result type is object. We reach this when
# -> We know the result type is truly object (e.g. .encode returns bytes
# or .findall returns a list).
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)


class StringArrayNumpySemantics(StringArray):
Expand Down Expand Up @@ -817,52 +884,3 @@ def value_counts(self, dropna: bool = True) -> Series:
# ------------------------------------------------------------------------
# String methods interface
_str_na_value = np.nan

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)
convert = convert and not np.all(mask)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
na_value_is_na = isna(na_value)
if na_value_is_na:
if is_integer_dtype(dtype):
na_value = 0
else:
na_value = True

result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
dtype=np.dtype(cast(type, dtype)),
)
if na_value_is_na and mask.any():
if is_integer_dtype(dtype):
result = result.astype("float64")
else:
result = result.astype("object")
result[mask] = np.nan
return result

elif is_string_dtype(dtype) and not is_object_dtype(dtype):
# i.e. StringDtype
result = lib.map_infer_mask(
arr, f, mask.view("uint8"), convert=False, na_value=na_value
)
return type(self)(result)
else:
# This is when the result type is object. We reach this when
# -> We know the result type is truly object (e.g. .encode returns bytes
# or .findall returns a list).
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))
113 changes: 45 additions & 68 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
from pandas.core.dtypes.common import (
is_bool_dtype,
is_integer_dtype,
is_object_dtype,
is_scalar,
is_string_dtype,
pandas_dtype,
)
from pandas.core.dtypes.missing import isna
Expand Down Expand Up @@ -281,9 +279,53 @@ def astype(self, dtype, copy: bool = True):
# base class "ObjectStringArrayMixin" defined the type as "float")
_str_na_value = libmissing.NA # type: ignore[assignment]

def _str_map_nan_semantics(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
if is_integer_dtype(dtype):
na_value = np.nan
else:
na_value = False

dtype = np.dtype(cast(type, dtype))
if mask.any():
# numpy int/bool dtypes cannot hold NaNs so we must convert to
# float64 for int (to match maybe_convert_objects) or
# object for bool (again to match maybe_convert_objects)
if is_integer_dtype(dtype):
dtype = np.dtype("float64")
else:
dtype = np.dtype(object)
result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
dtype=dtype,
)
return result

else:
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if self.dtype.na_value is np.nan:
return self._str_map_nan_semantics(
f, na_value=na_value, dtype=dtype, convert=convert
)

# TODO: de-duplicate with StringArray method. This method is moreless copy and
# paste.

Expand Down Expand Up @@ -327,21 +369,8 @@ def _str_map(

return constructor(result, mask)

elif is_string_dtype(dtype) and not is_object_dtype(dtype):
# i.e. StringDtype
result = lib.map_infer_mask(
arr, f, mask.view("uint8"), convert=False, na_value=na_value
)
result = pa.array(
result, mask=mask, type=pa.large_string(), from_pandas=True
)
return type(self)(result)
else:
# This is when the result type is object. We reach this when
# -> We know the result type is truly object (e.g. .encode returns bytes
# or .findall returns a list).
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
Expand Down Expand Up @@ -614,58 +643,6 @@ def __getattribute__(self, item):
return partial(getattr(ArrowStringArrayMixin, item), self)
return super().__getattribute__(item)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
if is_integer_dtype(dtype):
na_value = np.nan
else:
na_value = False

dtype = np.dtype(cast(type, dtype))
if mask.any():
# numpy int/bool dtypes cannot hold NaNs so we must convert to
# float64 for int (to match maybe_convert_objects) or
# object for bool (again to match maybe_convert_objects)
if is_integer_dtype(dtype):
dtype = np.dtype("float64")
else:
dtype = np.dtype(object)
result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
dtype=dtype,
)
return result

elif is_string_dtype(dtype) and not is_object_dtype(dtype):
# i.e. StringDtype
result = lib.map_infer_mask(
arr, f, mask.view("uint8"), convert=False, na_value=na_value
)
result = pa.array(
result, mask=mask, type=pa.large_string(), from_pandas=True
)
return type(self)(result)
else:
# This is when the result type is object. We reach this when
# -> We know the result type is truly object (e.g. .encode returns bytes
# or .findall returns a list).
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))

def _convert_int_dtype(self, result):
if isinstance(result, pa.Array):
result = result.to_numpy(zero_copy_only=False)
Expand Down