Skip to content

Commit 124359f

Browse files
jbrockmendel authored and WillAyd committed
REF (string): de-duplicate str_map_nan_semantics (#59464)
REF: de-duplicate str_map_nan_semantics
1 parent 4b8b315 commit 124359f

File tree

2 files changed

+5
-46
lines changed

2 files changed

+5
-46
lines changed

pandas/core/arrays/string_.py

+5 -4
Original file line number | Diff line number | Diff line change
@@ -391,7 +391,7 @@ def _str_map(
391391
return constructor(result, mask)
392392

393393
else:
394-
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
394+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask)
395395

396396
def _str_map_str_or_object(
397397
self,
@@ -400,7 +400,6 @@ def _str_map_str_or_object(
400400
arr: np.ndarray,
401401
f,
402402
mask: npt.NDArray[np.bool_],
403-
convert: bool,
404403
):
405404
# _str_map helper for case where dtype is either string dtype or object
406405
if is_string_dtype(dtype) and not is_object_dtype(dtype):
@@ -434,7 +433,6 @@ def _str_map_nan_semantics(
434433

435434
mask = isna(self)
436435
arr = np.asarray(self)
437-
convert = convert and not np.all(mask)
438436

439437
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
440438
na_value_is_na = isna(na_value)
@@ -453,6 +451,9 @@ def _str_map_nan_semantics(
453451
dtype=np.dtype(cast(type, dtype)),
454452
)
455453
if na_value_is_na and mask.any():
454+
# TODO: we could alternatively do this check before map_infer_mask
455+
# and adjust the dtype/na_value we pass there. Which is more
456+
# performant?
456457
if is_integer_dtype(dtype):
457458
result = result.astype("float64")
458459
else:
@@ -461,7 +462,7 @@ def _str_map_nan_semantics(
461462
return result
462463

463464
else:
464-
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
465+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask)
465466

466467

467468
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is

pandas/core/arrays/string_arrow.py

-42
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
TYPE_CHECKING,
88
Callable,
99
Union,
10-
cast,
1110
)
1211
import warnings
1312

@@ -24,8 +23,6 @@
2423
from pandas.util._exceptions import find_stack_level
2524

2625
from pandas.core.dtypes.common import (
27-
is_bool_dtype,
28-
is_integer_dtype,
2926
is_scalar,
3027
pandas_dtype,
3128
)
@@ -285,45 +282,6 @@ def _data(self):
285282

286283
_str_map = BaseStringArray._str_map
287284

288-
def _str_map_nan_semantics(
289-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
290-
):
291-
if dtype is None:
292-
dtype = self.dtype
293-
if na_value is None:
294-
na_value = self.dtype.na_value
295-
296-
mask = isna(self)
297-
arr = np.asarray(self)
298-
299-
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
300-
if is_integer_dtype(dtype):
301-
na_value = np.nan
302-
else:
303-
na_value = False
304-
305-
dtype = np.dtype(cast(type, dtype))
306-
if mask.any():
307-
# numpy int/bool dtypes cannot hold NaNs so we must convert to
308-
# float64 for int (to match maybe_convert_objects) or
309-
# object for bool (again to match maybe_convert_objects)
310-
if is_integer_dtype(dtype):
311-
dtype = np.dtype("float64")
312-
else:
313-
dtype = np.dtype(object)
314-
result = lib.map_infer_mask(
315-
arr,
316-
f,
317-
mask.view("uint8"),
318-
convert=False,
319-
na_value=na_value,
320-
dtype=dtype,
321-
)
322-
return result
323-
324-
else:
325-
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
326-
327285
def _str_contains(
328286
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
329287
):

0 commit comments

Comments
 (0)