Skip to content

Commit 40d81db

Browse files
jbrockmendeljorisvandenbossche
authored andcommitted
REF (string): de-duplicate _str_contains (#59709)
* REF: de-duplicate _str_contains * pyright ignore
1 parent 5b571c0 commit 40d81db

File tree

3 files changed

+19
-25
lines changed

3 files changed

+19
-25
lines changed

pandas/core/arrays/_arrow_string_mixins.py

+15
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,18 @@ def _str_istitle(self):
186186
def _str_isupper(self):
187187
result = pc.utf8_is_upper(self._pa_array)
188188
return self._convert_bool_result(result)
189+
190+
def _str_contains(
191+
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
192+
):
193+
if flags:
194+
raise NotImplementedError(f"contains not implemented with {flags=}")
195+
196+
if regex:
197+
pa_contains = pc.match_substring_regex
198+
else:
199+
pa_contains = pc.match_substring
200+
result = pa_contains(self._pa_array, pat, ignore_case=not case)
201+
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
202+
result = result.fill_null(na)
203+
return self._convert_bool_result(result)

pandas/core/arrays/arrow/array.py

-15
Original file line numberDiff line numberDiff line change
@@ -2296,21 +2296,6 @@ def _str_count(self, pat: str, flags: int = 0):
22962296
raise NotImplementedError(f"count not implemented with {flags=}")
22972297
return type(self)(pc.count_substring_regex(self._pa_array, pat))
22982298

2299-
def _str_contains(
2300-
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
2301-
):
2302-
if flags:
2303-
raise NotImplementedError(f"contains not implemented with {flags=}")
2304-
2305-
if regex:
2306-
pa_contains = pc.match_substring_regex
2307-
else:
2308-
pa_contains = pc.match_substring
2309-
result = pa_contains(self._pa_array, pat, ignore_case=not case)
2310-
if not isna(na):
2311-
result = result.fill_null(na)
2312-
return type(self)(result)
2313-
23142299
def _result_converter(self, result):
23152300
return type(self)(result)
23162301

pandas/core/arrays/string_arrow.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,8 @@ def insert(self, loc: int, item) -> ArrowStringArray:
214214
raise TypeError("Scalar must be NA or str")
215215
return super().insert(loc, item)
216216

217-
def _convert_bool_result(self, values, na=None):
217+
def _convert_bool_result(self, values):
218218
if self.dtype.na_value is np.nan:
219-
if not isna(na):
220-
values = values.fill_null(bool(na))
221219
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
222220
return BooleanDtype().__from_arrow__(values)
223221

@@ -305,11 +303,6 @@ def _str_contains(
305303
fallback_performancewarning()
306304
return super()._str_contains(pat, case, flags, na, regex)
307305

308-
if regex:
309-
result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case)
310-
else:
311-
result = pc.match_substring(self._pa_array, pat, ignore_case=not case)
312-
result = self._convert_bool_result(result, na=na)
313306
if not isna(na):
314307
if not isinstance(na, bool):
315308
# GH#59561
@@ -319,8 +312,9 @@ def _str_contains(
319312
FutureWarning,
320313
stacklevel=find_stack_level(),
321314
)
322-
result[isna(result)] = bool(na)
323-
return result
315+
na = bool(na)
316+
317+
return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)
324318

325319
def _str_replace(
326320
self,

0 commit comments

Comments
 (0)