Skip to content

Commit 6cb42b8

Browse files
committed
REF: de-duplicate _str_contains
1 parent 4f1052e commit 6cb42b8

File tree

3 files changed

+19
-25
lines changed

3 files changed

+19
-25
lines changed

pandas/core/arrays/_arrow_string_mixins.py

+15
Original file line number | Diff line number | Diff line change
@@ -190,3 +190,18 @@ def _str_istitle(self):
190190
def _str_isupper(self):
191191
result = pc.utf8_is_upper(self._pa_array)
192192
return self._convert_bool_result(result)
193+
194+
def _str_contains(
195+
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
196+
):
197+
if flags:
198+
raise NotImplementedError(f"contains not implemented with {flags=}")
199+
200+
if regex:
201+
pa_contains = pc.match_substring_regex
202+
else:
203+
pa_contains = pc.match_substring
204+
result = pa_contains(self._pa_array, pat, ignore_case=not case)
205+
if not isna(na):
206+
result = result.fill_null(na)
207+
return self._convert_bool_result(result)

pandas/core/arrays/arrow/array.py

-15
Original file line number | Diff line number | Diff line change
@@ -2322,21 +2322,6 @@ def _str_count(self, pat: str, flags: int = 0) -> Self:
23222322
raise NotImplementedError(f"count not implemented with {flags=}")
23232323
return type(self)(pc.count_substring_regex(self._pa_array, pat))
23242324

2325-
def _str_contains(
2326-
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
2327-
) -> Self:
2328-
if flags:
2329-
raise NotImplementedError(f"contains not implemented with {flags=}")
2330-
2331-
if regex:
2332-
pa_contains = pc.match_substring_regex
2333-
else:
2334-
pa_contains = pc.match_substring
2335-
result = pa_contains(self._pa_array, pat, ignore_case=not case)
2336-
if not isna(na):
2337-
result = result.fill_null(na)
2338-
return type(self)(result)
2339-
23402325
def _result_converter(self, result):
23412326
return type(self)(result)
23422327

pandas/core/arrays/string_arrow.py

+4 -10
Original file line number | Diff line number | Diff line change
@@ -223,10 +223,8 @@ def insert(self, loc: int, item) -> ArrowStringArray:
223223
raise TypeError("Scalar must be NA or str")
224224
return super().insert(loc, item)
225225

226-
def _convert_bool_result(self, values, na=None):
226+
def _convert_bool_result(self, values):
227227
if self.dtype.na_value is np.nan:
228-
if not isna(na):
229-
values = values.fill_null(bool(na))
230228
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
231229
return BooleanDtype().__from_arrow__(values)
232230

@@ -304,11 +302,6 @@ def _str_contains(
304302
fallback_performancewarning()
305303
return super()._str_contains(pat, case, flags, na, regex)
306304

307-
if regex:
308-
result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case)
309-
else:
310-
result = pc.match_substring(self._pa_array, pat, ignore_case=not case)
311-
result = self._convert_bool_result(result, na=na)
312305
if not isna(na):
313306
if not isinstance(na, bool):
314307
# GH#59561
@@ -318,8 +311,9 @@ def _str_contains(
318311
FutureWarning,
319312
stacklevel=find_stack_level(),
320313
)
321-
result[isna(result)] = bool(na)
322-
return result
314+
na = bool(na)
315+
316+
return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)
323317

324318
def _str_replace(
325319
self,

0 commit comments

Comments
 (0)