Skip to content

Commit 00bf889

Browse files
authored
ENH: Use more arrow compute functions for string[pyarrow] dtype (#54957)
1 parent 5b02305 commit 00bf889

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

pandas/core/arrays/string_arrow.py

+25-25
Original file line number | Diff line number | Diff line change
@@ -417,7 +417,7 @@ def _str_isupper(self):
417417

418418
def _str_len(self):
    """
    Compute the length of every element with ``pc.utf8_length``.

    The raw Arrow result is handed to ``self._convert_int_dtype`` so that
    each subclass decides which integer container is returned.
    """
    return self._convert_int_dtype(pc.utf8_length(self._pa_array))
421421

422422
def _str_lower(self):
    """
    Lower-case every element with ``pc.utf8_lower``.

    Returns a new array of the same class as ``self`` wrapping the result.
    """
    lowered = pc.utf8_lower(self._pa_array)
    return type(self)(lowered)
@@ -446,6 +446,29 @@ def _str_rstrip(self, to_strip=None):
446446
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
447447
return type(self)(result)
448448

449+
def _str_count(self, pat: str, flags: int = 0):
    """
    Count occurrences of the regex ``pat`` in every element.

    Parameters
    ----------
    pat : str
        Pattern passed to ``pc.count_substring_regex``.
    flags : int, default 0
        Regex flags. Any non-zero value is delegated to the parent class
        implementation instead of the Arrow kernel.

    Returns
    -------
    The counts converted by ``self._convert_int_dtype``, or the parent
    implementation's result when ``flags`` is set.
    """
    if not flags:
        counts = pc.count_substring_regex(self._pa_array, pat)
        return self._convert_int_dtype(counts)
    # Flags are not forwarded to the Arrow kernel; use the parent's path.
    return super()._str_count(pat, flags)
454+
455+
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
    """
    Find the lowest index of ``sub`` in every element (-1 when absent).

    Parameters
    ----------
    sub : str
        Substring to look for.
    start : int, default 0
        Left edge of the search window.
    end : int or None, default None
        Right edge of the search window; ``None`` means no right edge.

    Returns
    -------
    The positions converted by ``self._convert_int_dtype``. Argument
    combinations that are not handled with Arrow kernels here are
    delegated to the parent class implementation.
    """
    if start == 0 and end is None:
        # No window requested: search the stored array directly.
        result = pc.find_substring(self._pa_array, sub)
    elif start > 0 and end is not None and sub:
        windows = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
        result = pc.find_substring(windows, sub)
        # Hits are relative to the slice, which begins ``start`` positions
        # into the original string, so shift them by ``start`` (not by the
        # window width ``end - start``). Misses must stay at -1.
        not_found = pc.equal(result, -1)
        shifted = pc.add(result, start)
        result = pc.if_else(not_found, result, shifted)
    else:
        # Only one bound given, a negative ``start`` (counted from the end
        # of each string, so no single constant maps slice positions back),
        # or an empty ``sub`` (which can match inside an empty slice):
        # defer to the parent implementation.
        return super()._str_find(sub, start, end)
    return self._convert_int_dtype(result)
468+
469+
def _convert_int_dtype(self, result):
    """
    Convert an Arrow integer result through ``Int64Dtype.__from_arrow__``.

    Shared by the string methods that produce integer results so that
    subclasses can override the output container in one place.
    """
    target_dtype = Int64Dtype()
    return target_dtype.__from_arrow__(result)
471+
449472

450473
class ArrowStringArrayNumpySemantics(ArrowStringArray):
451474
_storage = "pyarrow_numpy"
@@ -526,34 +549,11 @@ def _str_map(
526549
return lib.map_infer_mask(arr, f, mask.view("uint8"))
527550

528551
def _convert_int_dtype(self, result):
552+
result = result.to_numpy()
529553
if result.dtype == np.int32:
530554
result = result.astype(np.int64)
531555
return result
532556

533-
def _str_count(self, pat: str, flags: int = 0):
534-
if flags:
535-
return super()._str_count(pat, flags)
536-
result = pc.count_substring_regex(self._pa_array, pat).to_numpy()
537-
return self._convert_int_dtype(result)
538-
539-
def _str_len(self):
540-
result = pc.utf8_length(self._pa_array).to_numpy()
541-
return self._convert_int_dtype(result)
542-
543-
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
544-
if start != 0 and end is not None:
545-
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
546-
result = pc.find_substring(slices, sub)
547-
not_found = pc.equal(result, -1)
548-
offset_result = pc.add(result, end - start)
549-
result = pc.if_else(not_found, result, offset_result)
550-
elif start == 0 and end is None:
551-
slices = self._pa_array
552-
result = pc.find_substring(slices, sub)
553-
else:
554-
return super()._str_find(sub, start, end)
555-
return self._convert_int_dtype(result.to_numpy())
556-
557557
def _cmp_method(self, other, op):
    """
    Compare with ``other`` using ``op`` and return a NumPy boolean array.

    The comparison itself is done by the parent class; its result is then
    converted with ``to_numpy(np.bool_, na_value=False)``, so missing
    entries come out as ``False``.
    """
    return super()._cmp_method(other, op).to_numpy(np.bool_, na_value=False)

0 commit comments

Comments
 (0)