Skip to content

REF (string): rename result converter methods #59626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ class ArrowStringArrayMixin:
def __init__(self, *args, **kwargs) -> None:
raise NotImplementedError

def _convert_bool_result(self, result):
# Convert a bool-dtype result to the appropriate result type
raise NotImplementedError

def _convert_int_result(self, result):
# Convert an integer-dtype result to the appropriate result type
raise NotImplementedError

def _str_pad(
self,
width: int,
Expand Down
6 changes: 6 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2311,6 +2311,12 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
for chunk in self._pa_array.iterchunks()
]

def _convert_bool_result(self, result):
return type(self)(result)

def _convert_int_result(self, result):
return type(self)(result)

def _str_count(self, pat: str, flags: int = 0) -> Self:
if flags:
raise NotImplementedError(f"count not implemented with {flags=}")
Expand Down
38 changes: 19 additions & 19 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def insert(self, loc: int, item) -> ArrowStringArray:
raise TypeError("Scalar must be NA or str")
return super().insert(loc, item)

def _result_converter(self, values, na=None):
def _convert_bool_result(self, values, na=None):
if self.dtype.na_value is np.nan:
if not isna(na):
values = values.fill_null(bool(na))
Expand Down Expand Up @@ -293,7 +293,7 @@ def _str_contains(
result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case)
else:
result = pc.match_substring(self._pa_array, pat, ignore_case=not case)
result = self._result_converter(result, na=na)
result = self._convert_bool_result(result, na=na)
if not isna(na):
result[isna(result)] = bool(na)
return result
Expand All @@ -315,7 +315,7 @@ def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
if not isna(na):
result = result.fill_null(na)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
if isinstance(pat, str):
Expand All @@ -334,7 +334,7 @@ def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
if not isna(na):
result = result.fill_null(na)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_replace(
self,
Expand Down Expand Up @@ -389,43 +389,43 @@ def _str_slice(

def _str_isalnum(self):
result = pc.utf8_is_alnum(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_isalpha(self):
result = pc.utf8_is_alpha(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_isdecimal(self):
result = pc.utf8_is_decimal(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_isdigit(self):
result = pc.utf8_is_digit(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_islower(self):
result = pc.utf8_is_lower(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_isnumeric(self):
result = pc.utf8_is_numeric(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_isspace(self):
result = pc.utf8_is_space(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_istitle(self):
result = pc.utf8_is_title(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_isupper(self):
result = pc.utf8_is_upper(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_len(self):
result = pc.utf8_length(self._pa_array)
return self._convert_int_dtype(result)
return self._convert_int_result(result)

def _str_lower(self) -> Self:
return type(self)(pc.utf8_lower(self._pa_array))
Expand Down Expand Up @@ -472,7 +472,7 @@ def _str_count(self, pat: str, flags: int = 0):
if flags:
return super()._str_count(pat, flags)
result = pc.count_substring_regex(self._pa_array, pat)
return self._convert_int_dtype(result)
return self._convert_int_result(result)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
if start != 0 and end is not None:
Expand All @@ -486,7 +486,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
result = pc.find_substring(slices, sub)
else:
return super()._str_find(sub, start, end)
return self._convert_int_dtype(result)
return self._convert_int_result(result)

def _str_get_dummies(self, sep: str = "|"):
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
Expand All @@ -495,7 +495,7 @@ def _str_get_dummies(self, sep: str = "|"):
dummies = np.vstack(dummies_pa.to_numpy())
return dummies.astype(np.int64, copy=False), labels

def _convert_int_dtype(self, result):
def _convert_int_result(self, result):
if self.dtype.na_value is np.nan:
if isinstance(result, pa.Array):
result = result.to_numpy(zero_copy_only=False)
Expand All @@ -522,7 +522,7 @@ def _reduce(

result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
return self._convert_int_dtype(result)
return self._convert_int_result(result)
elif isinstance(result, pa.Array):
return type(self)(result)
else:
Expand All @@ -540,7 +540,7 @@ def _rank(
"""
See Series.rank.__doc__.
"""
return self._convert_int_dtype(
return self._convert_int_result(
self._rank_calc(
axis=axis,
method=method,
Expand Down