Skip to content

Commit 972369f

Browse files
jbrockmendel authored and jorisvandenbossche committed
REF (string): rename result converter methods (pandas-dev#59626)
1 parent 62b474b commit 972369f

File tree

3 files changed, +33 -19 lines changed

pandas/core/arrays/_arrow_string_mixins.py

+8
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,14 @@ class ArrowStringArrayMixin:
1717
def __init__(self, *args, **kwargs) -> None:
1818
raise NotImplementedError
1919

20+
def _convert_bool_result(self, result):
21+
# Convert a bool-dtype result to the appropriate result type
22+
raise NotImplementedError
23+
24+
def _convert_int_result(self, result):
25+
# Convert an integer-dtype result to the appropriate result type
26+
raise NotImplementedError
27+
2028
def _str_pad(
2129
self,
2230
width: int,

pandas/core/arrays/arrow/array.py

+6
Original file line number | Diff line number | Diff line change
@@ -2285,6 +2285,12 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
22852285
for chunk in self._pa_array.iterchunks()
22862286
]
22872287

2288+
def _convert_bool_result(self, result):
2289+
return type(self)(result)
2290+
2291+
def _convert_int_result(self, result):
2292+
return type(self)(result)
2293+
22882294
def _str_count(self, pat: str, flags: int = 0):
22892295
if flags:
22902296
raise NotImplementedError(f"count not implemented with {flags=}")

pandas/core/arrays/string_arrow.py

+19 -19
Original file line number | Diff line number | Diff line change
@@ -214,7 +214,7 @@ def insert(self, loc: int, item) -> ArrowStringArray:
214214
raise TypeError("Scalar must be NA or str")
215215
return super().insert(loc, item)
216216

217-
def _result_converter(self, values, na=None):
217+
def _convert_bool_result(self, values, na=None):
218218
if self.dtype.na_value is np.nan:
219219
if not isna(na):
220220
values = values.fill_null(bool(na))
@@ -296,7 +296,7 @@ def _str_contains(
296296
result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case)
297297
else:
298298
result = pc.match_substring(self._pa_array, pat, ignore_case=not case)
299-
result = self._result_converter(result, na=na)
299+
result = self._convert_bool_result(result, na=na)
300300
if not isna(na):
301301
result[isna(result)] = bool(na)
302302
return result
@@ -318,7 +318,7 @@ def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
318318
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
319319
if not isna(na):
320320
result = result.fill_null(na)
321-
return self._result_converter(result)
321+
return self._convert_bool_result(result)
322322

323323
def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
324324
if isinstance(pat, str):
@@ -337,7 +337,7 @@ def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
337337
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
338338
if not isna(na):
339339
result = result.fill_null(na)
340-
return self._result_converter(result)
340+
return self._convert_bool_result(result)
341341

342342
def _str_replace(
343343
self,
@@ -389,43 +389,43 @@ def _str_slice(
389389

390390
def _str_isalnum(self):
391391
result = pc.utf8_is_alnum(self._pa_array)
392-
return self._result_converter(result)
392+
return self._convert_bool_result(result)
393393

394394
def _str_isalpha(self):
395395
result = pc.utf8_is_alpha(self._pa_array)
396-
return self._result_converter(result)
396+
return self._convert_bool_result(result)
397397

398398
def _str_isdecimal(self):
399399
result = pc.utf8_is_decimal(self._pa_array)
400-
return self._result_converter(result)
400+
return self._convert_bool_result(result)
401401

402402
def _str_isdigit(self):
403403
result = pc.utf8_is_digit(self._pa_array)
404-
return self._result_converter(result)
404+
return self._convert_bool_result(result)
405405

406406
def _str_islower(self):
407407
result = pc.utf8_is_lower(self._pa_array)
408-
return self._result_converter(result)
408+
return self._convert_bool_result(result)
409409

410410
def _str_isnumeric(self):
411411
result = pc.utf8_is_numeric(self._pa_array)
412-
return self._result_converter(result)
412+
return self._convert_bool_result(result)
413413

414414
def _str_isspace(self):
415415
result = pc.utf8_is_space(self._pa_array)
416-
return self._result_converter(result)
416+
return self._convert_bool_result(result)
417417

418418
def _str_istitle(self):
419419
result = pc.utf8_is_title(self._pa_array)
420-
return self._result_converter(result)
420+
return self._convert_bool_result(result)
421421

422422
def _str_isupper(self):
423423
result = pc.utf8_is_upper(self._pa_array)
424-
return self._result_converter(result)
424+
return self._convert_bool_result(result)
425425

426426
def _str_len(self):
427427
result = pc.utf8_length(self._pa_array)
428-
return self._convert_int_dtype(result)
428+
return self._convert_int_result(result)
429429

430430
def _str_lower(self):
431431
return type(self)(pc.utf8_lower(self._pa_array))
@@ -472,7 +472,7 @@ def _str_count(self, pat: str, flags: int = 0):
472472
if flags:
473473
return super()._str_count(pat, flags)
474474
result = pc.count_substring_regex(self._pa_array, pat)
475-
return self._convert_int_dtype(result)
475+
return self._convert_int_result(result)
476476

477477
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
478478
if start != 0 and end is not None:
@@ -486,7 +486,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
486486
result = pc.find_substring(slices, sub)
487487
else:
488488
return super()._str_find(sub, start, end)
489-
return self._convert_int_dtype(result)
489+
return self._convert_int_result(result)
490490

491491
def _str_get_dummies(self, sep: str = "|"):
492492
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
@@ -495,7 +495,7 @@ def _str_get_dummies(self, sep: str = "|"):
495495
dummies = np.vstack(dummies_pa.to_numpy())
496496
return dummies.astype(np.int64, copy=False), labels
497497

498-
def _convert_int_dtype(self, result):
498+
def _convert_int_result(self, result):
499499
if self.dtype.na_value is np.nan:
500500
if isinstance(result, pa.Array):
501501
result = result.to_numpy(zero_copy_only=False)
@@ -522,7 +522,7 @@ def _reduce(
522522

523523
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
524524
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
525-
return self._convert_int_dtype(result)
525+
return self._convert_int_result(result)
526526
elif isinstance(result, pa.Array):
527527
return type(self)(result)
528528
else:
@@ -540,7 +540,7 @@ def _rank(
540540
"""
541541
See Series.rank.__doc__.
542542
"""
543-
return self._convert_int_dtype(
543+
return self._convert_int_result(
544544
self._rank_calc(
545545
axis=axis,
546546
method=method,

0 commit comments

Comments
 (0)