2
2
from functools import wraps
3
3
import re
4
4
import textwrap
5
- from typing import TYPE_CHECKING , Any , Callable , Dict , List
5
+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Type , Union
6
6
import warnings
7
7
8
8
import numpy as np
@@ -142,7 +142,7 @@ def _map_stringarray(
142
142
The value to use for missing values. By default, this is
143
143
the original value (NA).
144
144
dtype : Dtype
145
- The result dtype to use. Specifying this aviods an intermediate
145
+ The result dtype to use. Specifying this avoids an intermediate
146
146
object-dtype allocation.
147
147
148
148
Returns
@@ -152,14 +152,20 @@ def _map_stringarray(
152
152
an ndarray.
153
153
154
154
"""
155
- from pandas .arrays import IntegerArray , StringArray
155
+ from pandas .arrays import IntegerArray , StringArray , BooleanArray
156
156
157
157
mask = isna (arr )
158
158
159
159
assert isinstance (arr , StringArray )
160
160
arr = np .asarray (arr )
161
161
162
- if is_integer_dtype (dtype ):
162
+ if is_integer_dtype (dtype ) or is_bool_dtype (dtype ):
163
+ constructor : Union [Type [IntegerArray ], Type [BooleanArray ]]
164
+ if is_integer_dtype (dtype ):
165
+ constructor = IntegerArray
166
+ else :
167
+ constructor = BooleanArray
168
+
163
169
na_value_is_na = isna (na_value )
164
170
if na_value_is_na :
165
171
na_value = 1
@@ -169,21 +175,20 @@ def _map_stringarray(
169
175
mask .view ("uint8" ),
170
176
convert = False ,
171
177
na_value = na_value ,
172
- dtype = np .dtype ("int64" ),
178
+ dtype = np .dtype (dtype ),
173
179
)
174
180
175
181
if not na_value_is_na :
176
182
mask [:] = False
177
183
178
- return IntegerArray (result , mask )
184
+ return constructor (result , mask )
179
185
180
186
elif is_string_dtype (dtype ) and not is_object_dtype (dtype ):
181
187
# i.e. StringDtype
182
188
result = lib .map_infer_mask (
183
189
arr , func , mask .view ("uint8" ), convert = False , na_value = na_value
184
190
)
185
191
return StringArray (result )
186
- # TODO: BooleanArray
187
192
else :
188
193
# This is when the result type is object. We reach this when
189
194
# -> We know the result type is truly object (e.g. .encode returns bytes
@@ -299,7 +304,7 @@ def str_count(arr, pat, flags=0):
299
304
"""
300
305
regex = re .compile (pat , flags = flags )
301
306
f = lambda x : len (regex .findall (x ))
302
- return _na_map (f , arr , dtype = int )
307
+ return _na_map (f , arr , dtype = "int64" )
303
308
304
309
305
310
def str_contains (arr , pat , case = True , flags = 0 , na = np .nan , regex = True ):
@@ -1365,7 +1370,7 @@ def str_find(arr, sub, start=0, end=None, side="left"):
1365
1370
else :
1366
1371
f = lambda x : getattr (x , method )(sub , start , end )
1367
1372
1368
- return _na_map (f , arr , dtype = int )
1373
+ return _na_map (f , arr , dtype = "int64" )
1369
1374
1370
1375
1371
1376
def str_index (arr , sub , start = 0 , end = None , side = "left" ):
@@ -1385,7 +1390,7 @@ def str_index(arr, sub, start=0, end=None, side="left"):
1385
1390
else :
1386
1391
f = lambda x : getattr (x , method )(sub , start , end )
1387
1392
1388
- return _na_map (f , arr , dtype = int )
1393
+ return _na_map (f , arr , dtype = "int64" )
1389
1394
1390
1395
1391
1396
def str_pad (arr , width , side = "left" , fillchar = " " ):
@@ -3210,7 +3215,7 @@ def rindex(self, sub, start=0, end=None):
3210
3215
len ,
3211
3216
docstring = _shared_docs ["len" ],
3212
3217
forbidden_types = None ,
3213
- dtype = int ,
3218
+ dtype = "int64" ,
3214
3219
returns_string = False ,
3215
3220
)
3216
3221
0 commit comments