Skip to content

Commit a24a653

Browse files
[backport 2.3.x] String dtype: propagate NaNs as False in predicate methods (eg .str.startswith) (#59616) (#60014)
* String dtype: propagate NaNs as False in predicate methods (eg .str.startswith) (#59616) (cherry picked from commit 88554d0) * ignore object dtype inference warnings
1 parent e3302bc commit a24a653

12 files changed

+316
-146
lines changed

pandas/core/arrays/_arrow_string_mixins.py

+27-17
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,14 @@
1010

1111
import numpy as np
1212

13+
from pandas._libs import lib
1314
from pandas.compat import (
1415
pa_version_under10p1,
1516
pa_version_under11p0,
1617
pa_version_under13p0,
1718
pa_version_under17p0,
1819
)
1920

20-
from pandas.core.dtypes.missing import isna
21-
2221
if not pa_version_under10p1:
2322
import pyarrow as pa
2423
import pyarrow.compute as pc
@@ -38,7 +37,7 @@ class ArrowStringArrayMixin:
3837
def __init__(self, *args, **kwargs) -> None:
3938
raise NotImplementedError
4039

41-
def _convert_bool_result(self, result):
40+
def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
4241
# Convert a bool-dtype result to the appropriate result type
4342
raise NotImplementedError
4443

@@ -212,7 +211,9 @@ def _str_removesuffix(self, suffix: str):
212211
result = pc.if_else(ends_with, removed, self._pa_array)
213212
return type(self)(result)
214213

215-
def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
214+
def _str_startswith(
215+
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
216+
):
216217
if isinstance(pat, str):
217218
result = pc.starts_with(self._pa_array, pattern=pat)
218219
else:
@@ -225,11 +226,11 @@ def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
225226

226227
for p in pat[1:]:
227228
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
228-
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
229-
result = result.fill_null(na)
230-
return self._convert_bool_result(result)
229+
return self._convert_bool_result(result, na=na, method_name="startswith")
231230

232-
def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
231+
def _str_endswith(
232+
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
233+
):
233234
if isinstance(pat, str):
234235
result = pc.ends_with(self._pa_array, pattern=pat)
235236
else:
@@ -242,9 +243,7 @@ def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
242243

243244
for p in pat[1:]:
244245
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
245-
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
246-
result = result.fill_null(na)
247-
return self._convert_bool_result(result)
246+
return self._convert_bool_result(result, na=na, method_name="endswith")
248247

249248
def _str_isalnum(self):
250249
result = pc.utf8_is_alnum(self._pa_array)
@@ -283,7 +282,12 @@ def _str_isupper(self):
283282
return self._convert_bool_result(result)
284283

285284
def _str_contains(
286-
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
285+
self,
286+
pat,
287+
case: bool = True,
288+
flags: int = 0,
289+
na: Scalar | lib.NoDefault = lib.no_default,
290+
regex: bool = True,
287291
):
288292
if flags:
289293
raise NotImplementedError(f"contains not implemented with {flags=}")
@@ -293,19 +297,25 @@ def _str_contains(
293297
else:
294298
pa_contains = pc.match_substring
295299
result = pa_contains(self._pa_array, pat, ignore_case=not case)
296-
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
297-
result = result.fill_null(na)
298-
return self._convert_bool_result(result)
300+
return self._convert_bool_result(result, na=na, method_name="contains")
299301

300302
def _str_match(
301-
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
303+
self,
304+
pat: str,
305+
case: bool = True,
306+
flags: int = 0,
307+
na: Scalar | lib.NoDefault = lib.no_default,
302308
):
303309
if not pat.startswith("^"):
304310
pat = f"^{pat}"
305311
return self._str_contains(pat, case, flags, na, regex=True)
306312

307313
def _str_fullmatch(
308-
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
314+
self,
315+
pat,
316+
case: bool = True,
317+
flags: int = 0,
318+
na: Scalar | lib.NoDefault = lib.no_default,
309319
):
310320
if not pat.endswith("$") or pat.endswith("\\$"):
311321
pat = f"{pat}$"

pandas/core/arrays/arrow/array.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -2285,7 +2285,11 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
22852285
for chunk in self._pa_array.iterchunks()
22862286
]
22872287

2288-
def _convert_bool_result(self, result):
2288+
def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
2289+
if na is not lib.no_default and not isna(
2290+
na
2291+
): # pyright: ignore [reportGeneralTypeIssues]
2292+
result = result.fill_null(na)
22892293
return type(self)(result)
22902294

22912295
def _convert_int_result(self, result):

pandas/core/arrays/categorical.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -2675,16 +2675,28 @@ def _replace(self, *, to_replace, value, inplace: bool = False):
26752675
# ------------------------------------------------------------------------
26762676
# String methods interface
26772677
def _str_map(
2678-
self, f, na_value=np.nan, dtype=np.dtype("object"), convert: bool = True
2678+
self, f, na_value=lib.no_default, dtype=np.dtype("object"), convert: bool = True
26792679
):
26802680
# Optimization to apply the callable `f` to the categories once
26812681
# and rebuild the result by `take`ing from the result with the codes.
26822682
# Returns the same type as the object-dtype implementation though.
2683-
from pandas.core.arrays import NumpyExtensionArray
2684-
26852683
categories = self.categories
26862684
codes = self.codes
2687-
result = NumpyExtensionArray(categories.to_numpy())._str_map(f, na_value, dtype)
2685+
if categories.dtype == "string":
2686+
result = categories.array._str_map(f, na_value, dtype) # type: ignore[attr-defined]
2687+
if (
2688+
categories.dtype.na_value is np.nan # type: ignore[union-attr]
2689+
and is_bool_dtype(dtype)
2690+
and (na_value is lib.no_default or isna(na_value))
2691+
):
2692+
# NaN propagates as False for functions with boolean return type
2693+
na_value = False
2694+
else:
2695+
from pandas.core.arrays import NumpyExtensionArray
2696+
2697+
result = NumpyExtensionArray(categories.to_numpy())._str_map(
2698+
f, na_value, dtype
2699+
)
26882700
return take_nd(result, codes, fill_value=na_value)
26892701

26902702
def _str_get_dummies(self, sep: str = "|"):

pandas/core/arrays/string_.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,11 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
377377
return cls._from_sequence(scalars, dtype=dtype)
378378

379379
def _str_map(
380-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
380+
self,
381+
f,
382+
na_value=lib.no_default,
383+
dtype: Dtype | None = None,
384+
convert: bool = True,
381385
):
382386
if self.dtype.na_value is np.nan:
383387
return self._str_map_nan_semantics(
@@ -388,7 +392,7 @@ def _str_map(
388392

389393
if dtype is None:
390394
dtype = self.dtype
391-
if na_value is None:
395+
if na_value is lib.no_default:
392396
na_value = self.dtype.na_value
393397

394398
mask = isna(self)
@@ -458,12 +462,20 @@ def _str_map_str_or_object(
458462
return lib.map_infer_mask(arr, f, mask.view("uint8"))
459463

460464
def _str_map_nan_semantics(
461-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
465+
self,
466+
f,
467+
na_value=lib.no_default,
468+
dtype: Dtype | None = None,
469+
convert: bool = True,
462470
):
463471
if dtype is None:
464472
dtype = self.dtype
465-
if na_value is None:
466-
na_value = self.dtype.na_value
473+
if na_value is lib.no_default:
474+
if is_bool_dtype(dtype):
475+
# NaN propagates as False
476+
na_value = False
477+
else:
478+
na_value = self.dtype.na_value
467479

468480
mask = isna(self)
469481
arr = np.asarray(self)
@@ -474,7 +486,8 @@ def _str_map_nan_semantics(
474486
if is_integer_dtype(dtype):
475487
na_value = 0
476488
else:
477-
na_value = True
489+
# NaN propagates as False
490+
na_value = False
478491

479492
result = lib.map_infer_mask(
480493
arr,
@@ -484,15 +497,13 @@ def _str_map_nan_semantics(
484497
na_value=na_value,
485498
dtype=np.dtype(cast(type, dtype)),
486499
)
487-
if na_value_is_na and mask.any():
500+
if na_value_is_na and is_integer_dtype(dtype) and mask.any():
488501
# TODO: we could alternatively do this check before map_infer_mask
489502
# and adjust the dtype/na_value we pass there. Which is more
490503
# performant?
491-
if is_integer_dtype(dtype):
492-
result = result.astype("float64")
493-
else:
494-
result = result.astype("object")
504+
result = result.astype("float64")
495505
result[mask] = np.nan
506+
496507
return result
497508

498509
else:

pandas/core/arrays/string_arrow.py

+28-14
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,29 @@ def insert(self, loc: int, item) -> ArrowStringArray:
211211
raise TypeError("Scalar must be NA or str")
212212
return super().insert(loc, item)
213213

214-
def _convert_bool_result(self, values):
214+
def _convert_bool_result(self, values, na=lib.no_default, method_name=None):
215+
if na is not lib.no_default and not isna(na) and not isinstance(na, bool):
216+
# GH#59561
217+
warnings.warn(
218+
f"Allowing a non-bool 'na' in obj.str.{method_name} is deprecated "
219+
"and will raise in a future version.",
220+
FutureWarning,
221+
stacklevel=find_stack_level(),
222+
)
223+
na = bool(na)
224+
215225
if self.dtype.na_value is np.nan:
216-
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
226+
if na is lib.no_default or isna(na):
227+
# NaN propagates as False
228+
values = values.fill_null(False)
229+
else:
230+
values = values.fill_null(na)
231+
return values.to_numpy()
232+
else:
233+
if na is not lib.no_default and not isna(
234+
na
235+
): # pyright: ignore [reportGeneralTypeIssues]
236+
values = values.fill_null(na)
217237
return BooleanDtype().__from_arrow__(values)
218238

219239
def _maybe_convert_setitem_value(self, value):
@@ -309,22 +329,16 @@ def _data(self):
309329
_str_slice = ArrowStringArrayMixin._str_slice
310330

311331
def _str_contains(
312-
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
332+
self,
333+
pat,
334+
case: bool = True,
335+
flags: int = 0,
336+
na=lib.no_default,
337+
regex: bool = True,
313338
):
314339
if flags:
315340
return super()._str_contains(pat, case, flags, na, regex)
316341

317-
if not isna(na):
318-
if not isinstance(na, bool):
319-
# GH#59561
320-
warnings.warn(
321-
"Allowing a non-bool 'na' in obj.str.contains is deprecated "
322-
"and will raise in a future version.",
323-
FutureWarning,
324-
stacklevel=find_stack_level(),
325-
)
326-
na = bool(na)
327-
328342
return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)
329343

330344
def _str_replace(

0 commit comments

Comments
 (0)