Skip to content

Commit 6752935

Browse files
authored
REF (string): rename result converter methods (#59626)
1 parent 5ad25d0 commit 6752935

File tree

3 files changed

+33
-19
lines changed

3 files changed

+33
-19
lines changed

pandas/core/arrays/_arrow_string_mixins.py

+8
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ class ArrowStringArrayMixin:
2323
def __init__(self, *args, **kwargs) -> None:
2424
raise NotImplementedError
2525

26+
def _convert_bool_result(self, result):
27+
# Convert a bool-dtype result to the appropriate result type
28+
raise NotImplementedError
29+
30+
def _convert_int_result(self, result):
31+
# Convert an integer-dtype result to the appropriate result type
32+
raise NotImplementedError
33+
2634
def _str_pad(
2735
self,
2836
width: int,

pandas/core/arrays/arrow/array.py

+6
Original file line numberDiff line numberDiff line change
@@ -2311,6 +2311,12 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
23112311
for chunk in self._pa_array.iterchunks()
23122312
]
23132313

2314+
def _convert_bool_result(self, result):
2315+
return type(self)(result)
2316+
2317+
def _convert_int_result(self, result):
2318+
return type(self)(result)
2319+
23142320
def _str_count(self, pat: str, flags: int = 0) -> Self:
23152321
if flags:
23162322
raise NotImplementedError(f"count not implemented with {flags=}")

pandas/core/arrays/string_arrow.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def insert(self, loc: int, item) -> ArrowStringArray:
221221
raise TypeError("Scalar must be NA or str")
222222
return super().insert(loc, item)
223223

224-
def _result_converter(self, values, na=None):
224+
def _convert_bool_result(self, values, na=None):
225225
if self.dtype.na_value is np.nan:
226226
if not isna(na):
227227
values = values.fill_null(bool(na))
@@ -293,7 +293,7 @@ def _str_contains(
293293
result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case)
294294
else:
295295
result = pc.match_substring(self._pa_array, pat, ignore_case=not case)
296-
result = self._result_converter(result, na=na)
296+
result = self._convert_bool_result(result, na=na)
297297
if not isna(na):
298298
result[isna(result)] = bool(na)
299299
return result
@@ -315,7 +315,7 @@ def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
315315
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
316316
if not isna(na):
317317
result = result.fill_null(na)
318-
return self._result_converter(result)
318+
return self._convert_bool_result(result)
319319

320320
def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
321321
if isinstance(pat, str):
@@ -334,7 +334,7 @@ def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
334334
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
335335
if not isna(na):
336336
result = result.fill_null(na)
337-
return self._result_converter(result)
337+
return self._convert_bool_result(result)
338338

339339
def _str_replace(
340340
self,
@@ -387,43 +387,43 @@ def _str_slice(
387387

388388
def _str_isalnum(self):
389389
result = pc.utf8_is_alnum(self._pa_array)
390-
return self._result_converter(result)
390+
return self._convert_bool_result(result)
391391

392392
def _str_isalpha(self):
393393
result = pc.utf8_is_alpha(self._pa_array)
394-
return self._result_converter(result)
394+
return self._convert_bool_result(result)
395395

396396
def _str_isdecimal(self):
397397
result = pc.utf8_is_decimal(self._pa_array)
398-
return self._result_converter(result)
398+
return self._convert_bool_result(result)
399399

400400
def _str_isdigit(self):
401401
result = pc.utf8_is_digit(self._pa_array)
402-
return self._result_converter(result)
402+
return self._convert_bool_result(result)
403403

404404
def _str_islower(self):
405405
result = pc.utf8_is_lower(self._pa_array)
406-
return self._result_converter(result)
406+
return self._convert_bool_result(result)
407407

408408
def _str_isnumeric(self):
409409
result = pc.utf8_is_numeric(self._pa_array)
410-
return self._result_converter(result)
410+
return self._convert_bool_result(result)
411411

412412
def _str_isspace(self):
413413
result = pc.utf8_is_space(self._pa_array)
414-
return self._result_converter(result)
414+
return self._convert_bool_result(result)
415415

416416
def _str_istitle(self):
417417
result = pc.utf8_is_title(self._pa_array)
418-
return self._result_converter(result)
418+
return self._convert_bool_result(result)
419419

420420
def _str_isupper(self):
421421
result = pc.utf8_is_upper(self._pa_array)
422-
return self._result_converter(result)
422+
return self._convert_bool_result(result)
423423

424424
def _str_len(self):
425425
result = pc.utf8_length(self._pa_array)
426-
return self._convert_int_dtype(result)
426+
return self._convert_int_result(result)
427427

428428
def _str_lower(self) -> Self:
429429
return type(self)(pc.utf8_lower(self._pa_array))
@@ -470,7 +470,7 @@ def _str_count(self, pat: str, flags: int = 0):
470470
if flags:
471471
return super()._str_count(pat, flags)
472472
result = pc.count_substring_regex(self._pa_array, pat)
473-
return self._convert_int_dtype(result)
473+
return self._convert_int_result(result)
474474

475475
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
476476
if start != 0 and end is not None:
@@ -484,7 +484,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
484484
result = pc.find_substring(slices, sub)
485485
else:
486486
return super()._str_find(sub, start, end)
487-
return self._convert_int_dtype(result)
487+
return self._convert_int_result(result)
488488

489489
def _str_get_dummies(self, sep: str = "|"):
490490
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
@@ -493,7 +493,7 @@ def _str_get_dummies(self, sep: str = "|"):
493493
dummies = np.vstack(dummies_pa.to_numpy())
494494
return dummies.astype(np.int64, copy=False), labels
495495

496-
def _convert_int_dtype(self, result):
496+
def _convert_int_result(self, result):
497497
if self.dtype.na_value is np.nan:
498498
if isinstance(result, pa.Array):
499499
result = result.to_numpy(zero_copy_only=False)
@@ -520,7 +520,7 @@ def _reduce(
520520

521521
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
522522
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
523-
return self._convert_int_dtype(result)
523+
return self._convert_int_result(result)
524524
elif isinstance(result, pa.Array):
525525
return type(self)(result)
526526
else:
@@ -538,7 +538,7 @@ def _rank(
538538
"""
539539
See Series.rank.__doc__.
540540
"""
541-
return self._convert_int_dtype(
541+
return self._convert_int_result(
542542
self._rank_calc(
543543
axis=axis,
544544
method=method,

0 commit comments

Comments
 (0)