Skip to content

Commit 88554d0

Browse files
String dtype: propagate NaNs as False in predicate methods (eg .str.startswith) (#59616)
1 parent 8303af3 commit 88554d0

12 files changed

+307
-146
lines changed

pandas/core/arrays/_arrow_string_mixins.py

+27 -17
Original file line number | Diff line number | Diff line change
@@ -10,15 +10,14 @@
1010

1111
import numpy as np
1212

13+
from pandas._libs import lib
1314
from pandas.compat import (
1415
pa_version_under10p1,
1516
pa_version_under11p0,
1617
pa_version_under13p0,
1718
pa_version_under17p0,
1819
)
1920

20-
from pandas.core.dtypes.missing import isna
21-
2221
if not pa_version_under10p1:
2322
import pyarrow as pa
2423
import pyarrow.compute as pc
@@ -38,7 +37,7 @@ class ArrowStringArrayMixin:
3837
def __init__(self, *args, **kwargs) -> None:
3938
raise NotImplementedError
4039

41-
def _convert_bool_result(self, result):
40+
def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
4241
# Convert a bool-dtype result to the appropriate result type
4342
raise NotImplementedError
4443

@@ -212,7 +211,9 @@ def _str_removesuffix(self, suffix: str):
212211
result = pc.if_else(ends_with, removed, self._pa_array)
213212
return type(self)(result)
214213

215-
def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
214+
def _str_startswith(
215+
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
216+
):
216217
if isinstance(pat, str):
217218
result = pc.starts_with(self._pa_array, pattern=pat)
218219
else:
@@ -225,11 +226,11 @@ def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
225226

226227
for p in pat[1:]:
227228
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
228-
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
229-
result = result.fill_null(na)
230-
return self._convert_bool_result(result)
229+
return self._convert_bool_result(result, na=na, method_name="startswith")
231230

232-
def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
231+
def _str_endswith(
232+
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
233+
):
233234
if isinstance(pat, str):
234235
result = pc.ends_with(self._pa_array, pattern=pat)
235236
else:
@@ -242,9 +243,7 @@ def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
242243

243244
for p in pat[1:]:
244245
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
245-
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
246-
result = result.fill_null(na)
247-
return self._convert_bool_result(result)
246+
return self._convert_bool_result(result, na=na, method_name="endswith")
248247

249248
def _str_isalnum(self):
250249
result = pc.utf8_is_alnum(self._pa_array)
@@ -283,7 +282,12 @@ def _str_isupper(self):
283282
return self._convert_bool_result(result)
284283

285284
def _str_contains(
286-
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
285+
self,
286+
pat,
287+
case: bool = True,
288+
flags: int = 0,
289+
na: Scalar | lib.NoDefault = lib.no_default,
290+
regex: bool = True,
287291
):
288292
if flags:
289293
raise NotImplementedError(f"contains not implemented with {flags=}")
@@ -293,19 +297,25 @@ def _str_contains(
293297
else:
294298
pa_contains = pc.match_substring
295299
result = pa_contains(self._pa_array, pat, ignore_case=not case)
296-
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
297-
result = result.fill_null(na)
298-
return self._convert_bool_result(result)
300+
return self._convert_bool_result(result, na=na, method_name="contains")
299301

300302
def _str_match(
301-
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
303+
self,
304+
pat: str,
305+
case: bool = True,
306+
flags: int = 0,
307+
na: Scalar | lib.NoDefault = lib.no_default,
302308
):
303309
if not pat.startswith("^"):
304310
pat = f"^{pat}"
305311
return self._str_contains(pat, case, flags, na, regex=True)
306312

307313
def _str_fullmatch(
308-
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
314+
self,
315+
pat,
316+
case: bool = True,
317+
flags: int = 0,
318+
na: Scalar | lib.NoDefault = lib.no_default,
309319
):
310320
if not pat.endswith("$") or pat.endswith("\\$"):
311321
pat = f"{pat}$"

pandas/core/arrays/arrow/array.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -2318,7 +2318,9 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
23182318
for chunk in self._pa_array.iterchunks()
23192319
]
23202320

2321-
def _convert_bool_result(self, result):
2321+
def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
2322+
if na is not lib.no_default and not isna(na): # pyright: ignore [reportGeneralTypeIssues]
2323+
result = result.fill_null(na)
23222324
return type(self)(result)
23232325

23242326
def _convert_int_result(self, result):

pandas/core/arrays/categorical.py

+16 -4
Original file line number | Diff line number | Diff line change
@@ -2679,16 +2679,28 @@ def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
26792679
# ------------------------------------------------------------------------
26802680
# String methods interface
26812681
def _str_map(
2682-
self, f, na_value=np.nan, dtype=np.dtype("object"), convert: bool = True
2682+
self, f, na_value=lib.no_default, dtype=np.dtype("object"), convert: bool = True
26832683
):
26842684
# Optimization to apply the callable `f` to the categories once
26852685
# and rebuild the result by `take`ing from the result with the codes.
26862686
# Returns the same type as the object-dtype implementation though.
2687-
from pandas.core.arrays import NumpyExtensionArray
2688-
26892687
categories = self.categories
26902688
codes = self.codes
2691-
result = NumpyExtensionArray(categories.to_numpy())._str_map(f, na_value, dtype)
2689+
if categories.dtype == "string":
2690+
result = categories.array._str_map(f, na_value, dtype) # type: ignore[attr-defined]
2691+
if (
2692+
categories.dtype.na_value is np.nan # type: ignore[union-attr]
2693+
and is_bool_dtype(dtype)
2694+
and (na_value is lib.no_default or isna(na_value))
2695+
):
2696+
# NaN propagates as False for functions with boolean return type
2697+
na_value = False
2698+
else:
2699+
from pandas.core.arrays import NumpyExtensionArray
2700+
2701+
result = NumpyExtensionArray(categories.to_numpy())._str_map(
2702+
f, na_value, dtype
2703+
)
26922704
return take_nd(result, codes, fill_value=na_value)
26932705

26942706
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):

pandas/core/arrays/string_.py

+20 -11
Original file line number | Diff line number | Diff line change
@@ -381,7 +381,11 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
381381
return cls._from_sequence(scalars, dtype=dtype)
382382

383383
def _str_map(
384-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
384+
self,
385+
f,
386+
na_value=lib.no_default,
387+
dtype: Dtype | None = None,
388+
convert: bool = True,
385389
):
386390
if self.dtype.na_value is np.nan:
387391
return self._str_map_nan_semantics(f, na_value=na_value, dtype=dtype)
@@ -390,7 +394,7 @@ def _str_map(
390394

391395
if dtype is None:
392396
dtype = self.dtype
393-
if na_value is None:
397+
if na_value is lib.no_default:
394398
na_value = self.dtype.na_value
395399

396400
mask = isna(self)
@@ -459,11 +463,17 @@ def _str_map_str_or_object(
459463
# -> We don't know the result type. E.g. `.get` can return anything.
460464
return lib.map_infer_mask(arr, f, mask.view("uint8"))
461465

462-
def _str_map_nan_semantics(self, f, na_value=None, dtype: Dtype | None = None):
466+
def _str_map_nan_semantics(
467+
self, f, na_value=lib.no_default, dtype: Dtype | None = None
468+
):
463469
if dtype is None:
464470
dtype = self.dtype
465-
if na_value is None:
466-
na_value = self.dtype.na_value
471+
if na_value is lib.no_default:
472+
if is_bool_dtype(dtype):
473+
# NaN propagates as False
474+
na_value = False
475+
else:
476+
na_value = self.dtype.na_value
467477

468478
mask = isna(self)
469479
arr = np.asarray(self)
@@ -474,7 +484,8 @@ def _str_map_nan_semantics(self, f, na_value=None, dtype: Dtype | None = None):
474484
if is_integer_dtype(dtype):
475485
na_value = 0
476486
else:
477-
na_value = True
487+
# NaN propagates as False
488+
na_value = False
478489

479490
result = lib.map_infer_mask(
480491
arr,
@@ -484,15 +495,13 @@ def _str_map_nan_semantics(self, f, na_value=None, dtype: Dtype | None = None):
484495
na_value=na_value,
485496
dtype=np.dtype(cast(type, dtype)),
486497
)
487-
if na_value_is_na and mask.any():
498+
if na_value_is_na and is_integer_dtype(dtype) and mask.any():
488499
# TODO: we could alternatively do this check before map_infer_mask
489500
# and adjust the dtype/na_value we pass there. Which is more
490501
# performant?
491-
if is_integer_dtype(dtype):
492-
result = result.astype("float64")
493-
else:
494-
result = result.astype("object")
502+
result = result.astype("float64")
495503
result[mask] = np.nan
504+
496505
return result
497506

498507
else:

pandas/core/arrays/string_arrow.py

+26 -14
Original file line number | Diff line number | Diff line change
@@ -219,9 +219,27 @@ def insert(self, loc: int, item) -> ArrowStringArray:
219219
raise TypeError("Scalar must be NA or str")
220220
return super().insert(loc, item)
221221

222-
def _convert_bool_result(self, values):
222+
def _convert_bool_result(self, values, na=lib.no_default, method_name=None):
223+
if na is not lib.no_default and not isna(na) and not isinstance(na, bool):
224+
# GH#59561
225+
warnings.warn(
226+
f"Allowing a non-bool 'na' in obj.str.{method_name} is deprecated "
227+
"and will raise in a future version.",
228+
FutureWarning,
229+
stacklevel=find_stack_level(),
230+
)
231+
na = bool(na)
232+
223233
if self.dtype.na_value is np.nan:
224-
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
234+
if na is lib.no_default or isna(na):
235+
# NaN propagates as False
236+
values = values.fill_null(False)
237+
else:
238+
values = values.fill_null(na)
239+
return values.to_numpy()
240+
else:
241+
if na is not lib.no_default and not isna(na): # pyright: ignore [reportGeneralTypeIssues]
242+
values = values.fill_null(na)
225243
return BooleanDtype().__from_arrow__(values)
226244

227245
def _maybe_convert_setitem_value(self, value):
@@ -306,22 +324,16 @@ def astype(self, dtype, copy: bool = True):
306324
_str_slice = ArrowStringArrayMixin._str_slice
307325

308326
def _str_contains(
309-
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
327+
self,
328+
pat,
329+
case: bool = True,
330+
flags: int = 0,
331+
na=lib.no_default,
332+
regex: bool = True,
310333
):
311334
if flags:
312335
return super()._str_contains(pat, case, flags, na, regex)
313336

314-
if not isna(na):
315-
if not isinstance(na, bool):
316-
# GH#59561
317-
warnings.warn(
318-
"Allowing a non-bool 'na' in obj.str.contains is deprecated "
319-
"and will raise in a future version.",
320-
FutureWarning,
321-
stacklevel=find_stack_level(),
322-
)
323-
na = bool(na)
324-
325337
return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)
326338

327339
def _str_replace(

pandas/core/strings/accessor.py

+25 -15
Original file line number | Diff line number | Diff line change
@@ -1225,7 +1225,12 @@ def join(self, sep: str):
12251225

12261226
@forbid_nonstring_types(["bytes"])
12271227
def contains(
1228-
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
1228+
self,
1229+
pat,
1230+
case: bool = True,
1231+
flags: int = 0,
1232+
na=lib.no_default,
1233+
regex: bool = True,
12291234
):
12301235
r"""
12311236
Test if pattern or regex is contained within a string of a Series or Index.
@@ -1243,8 +1248,9 @@ def contains(
12431248
Flags to pass through to the re module, e.g. re.IGNORECASE.
12441249
na : scalar, optional
12451250
Fill value for missing values. The default depends on dtype of the
1246-
array. For object-dtype, ``numpy.nan`` is used. For ``StringDtype``,
1247-
``pandas.NA`` is used.
1251+
array. For object-dtype, ``numpy.nan`` is used. For the nullable
1252+
``StringDtype``, ``pandas.NA`` is used. For the ``"str"`` dtype,
1253+
``False`` is used.
12481254
regex : bool, default True
12491255
If True, assumes the pat is a regular expression.
12501256
@@ -1362,7 +1368,7 @@ def contains(
13621368
return self._wrap_result(result, fill_value=na, returns_string=False)
13631369

13641370
@forbid_nonstring_types(["bytes"])
1365-
def match(self, pat: str, case: bool = True, flags: int = 0, na=None):
1371+
def match(self, pat: str, case: bool = True, flags: int = 0, na=lib.no_default):
13661372
"""
13671373
Determine if each string starts with a match of a regular expression.
13681374
@@ -1376,8 +1382,9 @@ def match(self, pat: str, case: bool = True, flags: int = 0, na=None):
13761382
Regex module flags, e.g. re.IGNORECASE.
13771383
na : scalar, optional
13781384
Fill value for missing values. The default depends on dtype of the
1379-
array. For object-dtype, ``numpy.nan`` is used. For ``StringDtype``,
1380-
``pandas.NA`` is used.
1385+
array. For object-dtype, ``numpy.nan`` is used. For the nullable
1386+
``StringDtype``, ``pandas.NA`` is used. For the ``"str"`` dtype,
1387+
``False`` is used.
13811388
13821389
Returns
13831390
-------
@@ -1406,7 +1413,7 @@ def match(self, pat: str, case: bool = True, flags: int = 0, na=None):
14061413
return self._wrap_result(result, fill_value=na, returns_string=False)
14071414

14081415
@forbid_nonstring_types(["bytes"])
1409-
def fullmatch(self, pat, case: bool = True, flags: int = 0, na=None):
1416+
def fullmatch(self, pat, case: bool = True, flags: int = 0, na=lib.no_default):
14101417
"""
14111418
Determine if each string entirely matches a regular expression.
14121419
@@ -1420,8 +1427,9 @@ def fullmatch(self, pat, case: bool = True, flags: int = 0, na=None):
14201427
Regex module flags, e.g. re.IGNORECASE.
14211428
na : scalar, optional
14221429
Fill value for missing values. The default depends on dtype of the
1423-
array. For object-dtype, ``numpy.nan`` is used. For ``StringDtype``,
1424-
``pandas.NA`` is used.
1430+
array. For object-dtype, ``numpy.nan`` is used. For the nullable
1431+
``StringDtype``, ``pandas.NA`` is used. For the ``"str"`` dtype,
1432+
``False`` is used.
14251433
14261434
Returns
14271435
-------
@@ -2612,7 +2620,7 @@ def count(self, pat, flags: int = 0):
26122620

26132621
@forbid_nonstring_types(["bytes"])
26142622
def startswith(
2615-
self, pat: str | tuple[str, ...], na: Scalar | None = None
2623+
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
26162624
) -> Series | Index:
26172625
"""
26182626
Test if the start of each string element matches a pattern.
@@ -2624,10 +2632,11 @@ def startswith(
26242632
pat : str or tuple[str, ...]
26252633
Character sequence or tuple of strings. Regular expressions are not
26262634
accepted.
2627-
na : object, default NaN
2635+
na : scalar, optional
26282636
Object shown if element tested is not a string. The default depends
26292637
on dtype of the array. For object-dtype, ``numpy.nan`` is used.
2630-
For ``StringDtype``, ``pandas.NA`` is used.
2638+
For the nullable ``StringDtype``, ``pandas.NA`` is used.
2639+
For the ``"str"`` dtype, ``False`` is used.
26312640
26322641
Returns
26332642
-------
@@ -2682,7 +2691,7 @@ def startswith(
26822691

26832692
@forbid_nonstring_types(["bytes"])
26842693
def endswith(
2685-
self, pat: str | tuple[str, ...], na: Scalar | None = None
2694+
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
26862695
) -> Series | Index:
26872696
"""
26882697
Test if the end of each string element matches a pattern.
@@ -2694,10 +2703,11 @@ def endswith(
26942703
pat : str or tuple[str, ...]
26952704
Character sequence or tuple of strings. Regular expressions are not
26962705
accepted.
2697-
na : object, default NaN
2706+
na : scalar, optional
26982707
Object shown if element tested is not a string. The default depends
26992708
on dtype of the array. For object-dtype, ``numpy.nan`` is used.
2700-
For ``StringDtype``, ``pandas.NA`` is used.
2709+
For the nullable ``StringDtype``, ``pandas.NA`` is used.
2710+
For the ``"str"`` dtype, ``False`` is used.
27012711
27022712
Returns
27032713
-------

0 commit comments

Comments
 (0)