2
2
from functools import wraps
3
3
import re
4
4
import textwrap
5
- from typing import Dict , List
5
+ from typing import TYPE_CHECKING , Any , Callable , Dict , List
6
6
import warnings
7
7
8
8
import numpy as np
15
15
ensure_object ,
16
16
is_bool_dtype ,
17
17
is_categorical_dtype ,
18
+ is_extension_array_dtype ,
18
19
is_integer ,
20
+ is_integer_dtype ,
19
21
is_list_like ,
22
+ is_object_dtype ,
20
23
is_re ,
21
24
is_scalar ,
25
+ is_string_dtype ,
22
26
)
23
27
from pandas .core .dtypes .generic import (
24
28
ABCDataFrame ,
28
32
)
29
33
from pandas .core .dtypes .missing import isna
30
34
35
+ from pandas ._typing import ArrayLike , Dtype
31
36
from pandas .core .algorithms import take_1d
32
37
from pandas .core .base import NoNewAttributesMixin
33
38
import pandas .core .common as com
39
+ from pandas .core .construction import extract_array
40
+
41
+ if TYPE_CHECKING :
42
+ from pandas .arrays import StringArray
34
43
35
44
_cpython_optimized_encoders = (
36
45
"utf-8" ,
@@ -109,10 +118,79 @@ def cat_safe(list_of_columns: List, sep: str):
109
118
110
119
def _na_map (f , arr , na_result = np .nan , dtype = object ):
111
120
# should really _check_ for NA
112
- return _map (f , arr , na_mask = True , na_value = na_result , dtype = dtype )
121
+ if is_extension_array_dtype (arr .dtype ):
122
+ # just StringDtype
123
+ arr = extract_array (arr )
124
+ return _map_stringarray (f , arr , na_value = na_result , dtype = dtype )
125
+ return _map_object (f , arr , na_mask = True , na_value = na_result , dtype = dtype )
126
+
127
+
128
+ def _map_stringarray (
129
+ func : Callable [[str ], Any ], arr : "StringArray" , na_value : Any , dtype : Dtype
130
+ ) -> ArrayLike :
131
+ """
132
+ Map a callable over valid elements of a StringArrray.
133
+
134
+ Parameters
135
+ ----------
136
+ func : Callable[[str], Any]
137
+ Apply to each valid element.
138
+ arr : StringArray
139
+ na_value : Any
140
+ The value to use for missing values. By default, this is
141
+ the original value (NA).
142
+ dtype : Dtype
143
+ The result dtype to use. Specifying this aviods an intermediate
144
+ object-dtype allocation.
145
+
146
+ Returns
147
+ -------
148
+ ArrayLike
149
+ An ExtensionArray for integer or string dtypes, otherwise
150
+ an ndarray.
151
+
152
+ """
153
+ from pandas .arrays import IntegerArray , StringArray
154
+
155
+ mask = isna (arr )
156
+
157
+ assert isinstance (arr , StringArray )
158
+ arr = np .asarray (arr )
159
+
160
+ if is_integer_dtype (dtype ):
161
+ na_value_is_na = isna (na_value )
162
+ if na_value_is_na :
163
+ na_value = 1
164
+ result = lib .map_infer_mask (
165
+ arr ,
166
+ func ,
167
+ mask .view ("uint8" ),
168
+ convert = False ,
169
+ na_value = na_value ,
170
+ dtype = np .dtype ("int64" ),
171
+ )
172
+
173
+ if not na_value_is_na :
174
+ mask [:] = False
175
+
176
+ return IntegerArray (result , mask )
177
+
178
+ elif is_string_dtype (dtype ) and not is_object_dtype (dtype ):
179
+ # i.e. StringDtype
180
+ result = lib .map_infer_mask (
181
+ arr , func , mask .view ("uint8" ), convert = False , na_value = na_value
182
+ )
183
+ return StringArray (result )
184
+ # TODO: BooleanArray
185
+ else :
186
+ # This is when the result type is object. We reach this when
187
+ # -> We know the result type is truly object (e.g. .encode returns bytes
188
+ # or .findall returns a list).
189
+ # -> We don't know the result type. E.g. `.get` can return anything.
190
+ return lib .map_infer_mask (arr , func , mask .view ("uint8" ))
113
191
114
192
115
- def _map (f , arr , na_mask = False , na_value = np .nan , dtype = object ):
193
+ def _map_object (f , arr , na_mask = False , na_value = np .nan , dtype = object ):
116
194
if not len (arr ):
117
195
return np .ndarray (0 , dtype = dtype )
118
196
@@ -143,7 +221,7 @@ def g(x):
143
221
except (TypeError , AttributeError ):
144
222
return na_value
145
223
146
- return _map (g , arr , dtype = dtype )
224
+ return _map_object (g , arr , dtype = dtype )
147
225
if na_value is not np .nan :
148
226
np .putmask (result , mask , na_value )
149
227
if result .dtype == object :
@@ -634,7 +712,7 @@ def str_replace(arr, pat, repl, n=-1, case=None, flags=0, regex=True):
634
712
raise ValueError ("Cannot use a callable replacement when regex=False" )
635
713
f = lambda x : x .replace (pat , repl , n )
636
714
637
- return _na_map (f , arr )
715
+ return _na_map (f , arr , dtype = str )
638
716
639
717
640
718
def str_repeat (arr , repeats ):
@@ -685,7 +763,7 @@ def scalar_rep(x):
685
763
except TypeError :
686
764
return str .__mul__ (x , repeats )
687
765
688
- return _na_map (scalar_rep , arr )
766
+ return _na_map (scalar_rep , arr , dtype = str )
689
767
else :
690
768
691
769
def rep (x , r ):
@@ -1150,7 +1228,7 @@ def str_join(arr, sep):
1150
1228
4 NaN
1151
1229
dtype: object
1152
1230
"""
1153
- return _na_map (sep .join , arr )
1231
+ return _na_map (sep .join , arr , dtype = str )
1154
1232
1155
1233
1156
1234
def str_findall (arr , pat , flags = 0 ):
@@ -1381,7 +1459,7 @@ def str_pad(arr, width, side="left", fillchar=" "):
1381
1459
else : # pragma: no cover
1382
1460
raise ValueError ("Invalid side" )
1383
1461
1384
- return _na_map (f , arr )
1462
+ return _na_map (f , arr , dtype = str )
1385
1463
1386
1464
1387
1465
def str_split (arr , pat = None , n = None ):
@@ -1487,7 +1565,7 @@ def str_slice(arr, start=None, stop=None, step=None):
1487
1565
"""
1488
1566
obj = slice (start , stop , step )
1489
1567
f = lambda x : x [obj ]
1490
- return _na_map (f , arr )
1568
+ return _na_map (f , arr , dtype = str )
1491
1569
1492
1570
1493
1571
def str_slice_replace (arr , start = None , stop = None , repl = None ):
@@ -1578,7 +1656,7 @@ def f(x):
1578
1656
y += x [local_stop :]
1579
1657
return y
1580
1658
1581
- return _na_map (f , arr )
1659
+ return _na_map (f , arr , dtype = str )
1582
1660
1583
1661
1584
1662
def str_strip (arr , to_strip = None , side = "both" ):
@@ -1603,7 +1681,7 @@ def str_strip(arr, to_strip=None, side="both"):
1603
1681
f = lambda x : x .rstrip (to_strip )
1604
1682
else : # pragma: no cover
1605
1683
raise ValueError ("Invalid side" )
1606
- return _na_map (f , arr )
1684
+ return _na_map (f , arr , dtype = str )
1607
1685
1608
1686
1609
1687
def str_wrap (arr , width , ** kwargs ):
@@ -1667,7 +1745,7 @@ def str_wrap(arr, width, **kwargs):
1667
1745
1668
1746
tw = textwrap .TextWrapper (** kwargs )
1669
1747
1670
- return _na_map (lambda s : "\n " .join (tw .wrap (s )), arr )
1748
+ return _na_map (lambda s : "\n " .join (tw .wrap (s )), arr , dtype = str )
1671
1749
1672
1750
1673
1751
def str_translate (arr , table ):
@@ -1687,7 +1765,7 @@ def str_translate(arr, table):
1687
1765
-------
1688
1766
Series or Index
1689
1767
"""
1690
- return _na_map (lambda x : x .translate (table ), arr )
1768
+ return _na_map (lambda x : x .translate (table ), arr , dtype = str )
1691
1769
1692
1770
1693
1771
def str_get (arr , i ):
@@ -3025,7 +3103,7 @@ def normalize(self, form):
3025
3103
import unicodedata
3026
3104
3027
3105
f = lambda x : unicodedata .normalize (form , x )
3028
- result = _na_map (f , self ._parent )
3106
+ result = _na_map (f , self ._parent , dtype = str )
3029
3107
return self ._wrap_result (result )
3030
3108
3031
3109
_shared_docs [
@@ -3223,31 +3301,37 @@ def rindex(self, sub, start=0, end=None):
3223
3301
lambda x : x .lower (),
3224
3302
name = "lower" ,
3225
3303
docstring = _shared_docs ["casemethods" ] % _doc_args ["lower" ],
3304
+ dtype = str ,
3226
3305
)
3227
3306
upper = _noarg_wrapper (
3228
3307
lambda x : x .upper (),
3229
3308
name = "upper" ,
3230
3309
docstring = _shared_docs ["casemethods" ] % _doc_args ["upper" ],
3310
+ dtype = str ,
3231
3311
)
3232
3312
title = _noarg_wrapper (
3233
3313
lambda x : x .title (),
3234
3314
name = "title" ,
3235
3315
docstring = _shared_docs ["casemethods" ] % _doc_args ["title" ],
3316
+ dtype = str ,
3236
3317
)
3237
3318
capitalize = _noarg_wrapper (
3238
3319
lambda x : x .capitalize (),
3239
3320
name = "capitalize" ,
3240
3321
docstring = _shared_docs ["casemethods" ] % _doc_args ["capitalize" ],
3322
+ dtype = str ,
3241
3323
)
3242
3324
swapcase = _noarg_wrapper (
3243
3325
lambda x : x .swapcase (),
3244
3326
name = "swapcase" ,
3245
3327
docstring = _shared_docs ["casemethods" ] % _doc_args ["swapcase" ],
3328
+ dtype = str ,
3246
3329
)
3247
3330
casefold = _noarg_wrapper (
3248
3331
lambda x : x .casefold (),
3249
3332
name = "casefold" ,
3250
3333
docstring = _shared_docs ["casemethods" ] % _doc_args ["casefold" ],
3334
+ dtype = str ,
3251
3335
)
3252
3336
3253
3337
_shared_docs [
0 commit comments