Skip to content

Commit 6b1f73e

Browse files
Refactor string methods for StringArray + return IntegerArray for numeric results (pandas-dev#29640)
1 parent fe1803d commit 6b1f73e

File tree

5 files changed

+201
-27
lines changed

5 files changed

+201
-27
lines changed

doc/source/user_guide/text.rst

+35-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Text Data Types
1313

1414
.. versionadded:: 1.0.0
1515

16-
There are two main ways to store text data
16+
There are two ways to store text data in pandas:
1717

1818
1. ``object`` -dtype NumPy array.
1919
2. :class:`StringDtype` extension type.
@@ -63,7 +63,40 @@ Or ``astype`` after the ``Series`` or ``DataFrame`` is created
6363
s
6464
s.astype("string")
6565
66-
Everything that follows in the rest of this document applies equally to
66+
.. _text.differences:
67+
68+
Behavior differences
69+
^^^^^^^^^^^^^^^^^^^^
70+
71+
These are places where the behavior of ``StringDtype`` objects differs from
72+
``object`` dtype
73+
74+
1. For ``StringDtype``, :ref:`string accessor methods<api.series.str>`
75+
that return **numeric** output will always return a nullable integer dtype,
76+
rather than either int or float dtype, depending on the presence of NA values.
77+
78+
.. ipython:: python
79+
80+
s = pd.Series(["a", None, "b"], dtype="string")
81+
s
82+
s.str.count("a")
83+
s.dropna().str.count("a")
84+
85+
Both outputs are ``Int64`` dtype. Compare that with object-dtype
86+
87+
.. ipython:: python
88+
89+
s.astype(object).str.count("a")
90+
s.astype(object).dropna().str.count("a")
91+
92+
When NA values are present, the output dtype is float64.
93+
94+
2. Some string methods, like :meth:`Series.str.decode`, are not available
95+
on ``StringArray`` because ``StringArray`` only holds strings, not
96+
bytes.
97+
98+
99+
Everything else that follows in the rest of this document applies equally to
67100
``string`` and ``object`` dtype.
68101

69102
.. _text.string_methods:

doc/source/whatsnew/v1.0.0.rst

+9-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ Previously, strings were typically stored in object-dtype NumPy arrays.
6363
``StringDtype`` is currently considered experimental. The implementation
6464
and parts of the API may change without warning.
6565

66-
The text extension type solves several issues with object-dtype NumPy arrays:
66+
The ``'string'`` extension type solves several issues with object-dtype NumPy arrays:
6767

6868
1. You can accidentally store a *mixture* of strings and non-strings in an
6969
``object`` dtype array. A ``StringArray`` can only store strings.
@@ -88,9 +88,17 @@ You can use the alias ``"string"`` as well.
8888
The usual string accessor methods work. Where appropriate, the return type
8989
of the Series or columns of a DataFrame will also have string dtype.
9090

91+
.. ipython:: python
92+
9193
s.str.upper()
9294
s.str.split('b', expand=True).dtypes
9395
96+
String accessor methods returning integers will return a value with :class:`Int64Dtype`:
97+
98+
.. ipython:: python
99+
100+
s.str.count("a")
101+
94102
We recommend explicitly using the ``string`` data type when working with strings.
95103
See :ref:`text.types` for more.
96104

pandas/_libs/lib.pyx

+20-4
Original file line numberDiff line numberDiff line change
@@ -2208,31 +2208,47 @@ def maybe_convert_objects(ndarray[object] objects, bint try_float=0,
22082208
return objects
22092209

22102210

2211+
_no_default = object()
2212+
2213+
22112214
@cython.boundscheck(False)
22122215
@cython.wraparound(False)
2213-
def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=1):
2216+
def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=1,
2217+
object na_value=_no_default, object dtype=object):
22142218
"""
22152219
Substitute for np.vectorize with pandas-friendly dtype inference
22162220
22172221
Parameters
22182222
----------
22192223
arr : ndarray
22202224
f : function
2225+
mask : ndarray
2226+
uint8 dtype ndarray indicating values not to apply `f` to.
2227+
convert : bool, default True
2228+
Whether to call `maybe_convert_objects` on the resulting ndarray
2229+
na_value : Any, optional
2230+
The result value to use for masked values. By default, the
2231+
input value is used.
2232+
dtype : numpy.dtype
2233+
The numpy dtype to use for the result ndarray.
22212234
22222235
Returns
22232236
-------
22242237
mapped : ndarray
22252238
"""
22262239
cdef:
22272240
Py_ssize_t i, n
2228-
ndarray[object] result
2241+
ndarray result
22292242
object val
22302243

22312244
n = len(arr)
2232-
result = np.empty(n, dtype=object)
2245+
result = np.empty(n, dtype=dtype)
22332246
for i in range(n):
22342247
if mask[i]:
2235-
val = arr[i]
2248+
if na_value is _no_default:
2249+
val = arr[i]
2250+
else:
2251+
val = na_value
22362252
else:
22372253
val = f(arr[i])
22382254

pandas/core/strings.py

+98-14
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from functools import wraps
33
import re
44
import textwrap
5-
from typing import Dict, List
5+
from typing import TYPE_CHECKING, Any, Callable, Dict, List
66
import warnings
77

88
import numpy as np
@@ -15,10 +15,14 @@
1515
ensure_object,
1616
is_bool_dtype,
1717
is_categorical_dtype,
18+
is_extension_array_dtype,
1819
is_integer,
20+
is_integer_dtype,
1921
is_list_like,
22+
is_object_dtype,
2023
is_re,
2124
is_scalar,
25+
is_string_dtype,
2226
)
2327
from pandas.core.dtypes.generic import (
2428
ABCDataFrame,
@@ -28,9 +32,14 @@
2832
)
2933
from pandas.core.dtypes.missing import isna
3034

35+
from pandas._typing import ArrayLike, Dtype
3136
from pandas.core.algorithms import take_1d
3237
from pandas.core.base import NoNewAttributesMixin
3338
import pandas.core.common as com
39+
from pandas.core.construction import extract_array
40+
41+
if TYPE_CHECKING:
42+
from pandas.arrays import StringArray
3443

3544
_cpython_optimized_encoders = (
3645
"utf-8",
@@ -109,10 +118,79 @@ def cat_safe(list_of_columns: List, sep: str):
109118

110119
def _na_map(f, arr, na_result=np.nan, dtype=object):
111120
# should really _check_ for NA
112-
return _map(f, arr, na_mask=True, na_value=na_result, dtype=dtype)
121+
if is_extension_array_dtype(arr.dtype):
122+
# just StringDtype
123+
arr = extract_array(arr)
124+
return _map_stringarray(f, arr, na_value=na_result, dtype=dtype)
125+
return _map_object(f, arr, na_mask=True, na_value=na_result, dtype=dtype)
126+
127+
128+
def _map_stringarray(
129+
func: Callable[[str], Any], arr: "StringArray", na_value: Any, dtype: Dtype
130+
) -> ArrayLike:
131+
"""
132+
Map a callable over valid elements of a StringArray.
133+
134+
Parameters
135+
----------
136+
func : Callable[[str], Any]
137+
Apply to each valid element.
138+
arr : StringArray
139+
na_value : Any
140+
The value to use for missing values. By default, this is
141+
the original value (NA).
142+
dtype : Dtype
143+
The result dtype to use. Specifying this avoids an intermediate
144+
object-dtype allocation.
145+
146+
Returns
147+
-------
148+
ArrayLike
149+
An ExtensionArray for integer or string dtypes, otherwise
150+
an ndarray.
151+
152+
"""
153+
from pandas.arrays import IntegerArray, StringArray
154+
155+
mask = isna(arr)
156+
157+
assert isinstance(arr, StringArray)
158+
arr = np.asarray(arr)
159+
160+
if is_integer_dtype(dtype):
161+
na_value_is_na = isna(na_value)
162+
if na_value_is_na:
163+
na_value = 1
164+
result = lib.map_infer_mask(
165+
arr,
166+
func,
167+
mask.view("uint8"),
168+
convert=False,
169+
na_value=na_value,
170+
dtype=np.dtype("int64"),
171+
)
172+
173+
if not na_value_is_na:
174+
mask[:] = False
175+
176+
return IntegerArray(result, mask)
177+
178+
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
179+
# i.e. StringDtype
180+
result = lib.map_infer_mask(
181+
arr, func, mask.view("uint8"), convert=False, na_value=na_value
182+
)
183+
return StringArray(result)
184+
# TODO: BooleanArray
185+
else:
186+
# This is when the result type is object. We reach this when
187+
# -> We know the result type is truly object (e.g. .encode returns bytes
188+
# or .findall returns a list).
189+
# -> We don't know the result type. E.g. `.get` can return anything.
190+
return lib.map_infer_mask(arr, func, mask.view("uint8"))
113191

114192

115-
def _map(f, arr, na_mask=False, na_value=np.nan, dtype=object):
193+
def _map_object(f, arr, na_mask=False, na_value=np.nan, dtype=object):
116194
if not len(arr):
117195
return np.ndarray(0, dtype=dtype)
118196

@@ -143,7 +221,7 @@ def g(x):
143221
except (TypeError, AttributeError):
144222
return na_value
145223

146-
return _map(g, arr, dtype=dtype)
224+
return _map_object(g, arr, dtype=dtype)
147225
if na_value is not np.nan:
148226
np.putmask(result, mask, na_value)
149227
if result.dtype == object:
@@ -634,7 +712,7 @@ def str_replace(arr, pat, repl, n=-1, case=None, flags=0, regex=True):
634712
raise ValueError("Cannot use a callable replacement when regex=False")
635713
f = lambda x: x.replace(pat, repl, n)
636714

637-
return _na_map(f, arr)
715+
return _na_map(f, arr, dtype=str)
638716

639717

640718
def str_repeat(arr, repeats):
@@ -685,7 +763,7 @@ def scalar_rep(x):
685763
except TypeError:
686764
return str.__mul__(x, repeats)
687765

688-
return _na_map(scalar_rep, arr)
766+
return _na_map(scalar_rep, arr, dtype=str)
689767
else:
690768

691769
def rep(x, r):
@@ -1150,7 +1228,7 @@ def str_join(arr, sep):
11501228
4 NaN
11511229
dtype: object
11521230
"""
1153-
return _na_map(sep.join, arr)
1231+
return _na_map(sep.join, arr, dtype=str)
11541232

11551233

11561234
def str_findall(arr, pat, flags=0):
@@ -1381,7 +1459,7 @@ def str_pad(arr, width, side="left", fillchar=" "):
13811459
else: # pragma: no cover
13821460
raise ValueError("Invalid side")
13831461

1384-
return _na_map(f, arr)
1462+
return _na_map(f, arr, dtype=str)
13851463

13861464

13871465
def str_split(arr, pat=None, n=None):
@@ -1487,7 +1565,7 @@ def str_slice(arr, start=None, stop=None, step=None):
14871565
"""
14881566
obj = slice(start, stop, step)
14891567
f = lambda x: x[obj]
1490-
return _na_map(f, arr)
1568+
return _na_map(f, arr, dtype=str)
14911569

14921570

14931571
def str_slice_replace(arr, start=None, stop=None, repl=None):
@@ -1578,7 +1656,7 @@ def f(x):
15781656
y += x[local_stop:]
15791657
return y
15801658

1581-
return _na_map(f, arr)
1659+
return _na_map(f, arr, dtype=str)
15821660

15831661

15841662
def str_strip(arr, to_strip=None, side="both"):
@@ -1603,7 +1681,7 @@ def str_strip(arr, to_strip=None, side="both"):
16031681
f = lambda x: x.rstrip(to_strip)
16041682
else: # pragma: no cover
16051683
raise ValueError("Invalid side")
1606-
return _na_map(f, arr)
1684+
return _na_map(f, arr, dtype=str)
16071685

16081686

16091687
def str_wrap(arr, width, **kwargs):
@@ -1667,7 +1745,7 @@ def str_wrap(arr, width, **kwargs):
16671745

16681746
tw = textwrap.TextWrapper(**kwargs)
16691747

1670-
return _na_map(lambda s: "\n".join(tw.wrap(s)), arr)
1748+
return _na_map(lambda s: "\n".join(tw.wrap(s)), arr, dtype=str)
16711749

16721750

16731751
def str_translate(arr, table):
@@ -1687,7 +1765,7 @@ def str_translate(arr, table):
16871765
-------
16881766
Series or Index
16891767
"""
1690-
return _na_map(lambda x: x.translate(table), arr)
1768+
return _na_map(lambda x: x.translate(table), arr, dtype=str)
16911769

16921770

16931771
def str_get(arr, i):
@@ -3025,7 +3103,7 @@ def normalize(self, form):
30253103
import unicodedata
30263104

30273105
f = lambda x: unicodedata.normalize(form, x)
3028-
result = _na_map(f, self._parent)
3106+
result = _na_map(f, self._parent, dtype=str)
30293107
return self._wrap_result(result)
30303108

30313109
_shared_docs[
@@ -3223,31 +3301,37 @@ def rindex(self, sub, start=0, end=None):
32233301
lambda x: x.lower(),
32243302
name="lower",
32253303
docstring=_shared_docs["casemethods"] % _doc_args["lower"],
3304+
dtype=str,
32263305
)
32273306
upper = _noarg_wrapper(
32283307
lambda x: x.upper(),
32293308
name="upper",
32303309
docstring=_shared_docs["casemethods"] % _doc_args["upper"],
3310+
dtype=str,
32313311
)
32323312
title = _noarg_wrapper(
32333313
lambda x: x.title(),
32343314
name="title",
32353315
docstring=_shared_docs["casemethods"] % _doc_args["title"],
3316+
dtype=str,
32363317
)
32373318
capitalize = _noarg_wrapper(
32383319
lambda x: x.capitalize(),
32393320
name="capitalize",
32403321
docstring=_shared_docs["casemethods"] % _doc_args["capitalize"],
3322+
dtype=str,
32413323
)
32423324
swapcase = _noarg_wrapper(
32433325
lambda x: x.swapcase(),
32443326
name="swapcase",
32453327
docstring=_shared_docs["casemethods"] % _doc_args["swapcase"],
3328+
dtype=str,
32463329
)
32473330
casefold = _noarg_wrapper(
32483331
lambda x: x.casefold(),
32493332
name="casefold",
32503333
docstring=_shared_docs["casemethods"] % _doc_args["casefold"],
3334+
dtype=str,
32513335
)
32523336

32533337
_shared_docs[

0 commit comments

Comments
 (0)