2
2
3
3
from pandas .compat import zip
4
4
from pandas .core .common import (isnull , _values_from_object , is_bool_dtype , is_list_like ,
5
- is_categorical_dtype , is_object_dtype )
5
+ is_categorical_dtype , is_object_dtype , take_1d )
6
6
import pandas .compat as compat
7
7
from pandas .core .base import AccessorProperty , NoNewAttributesMixin
8
8
from pandas .util .decorators import Appender , deprecate_kwarg
@@ -1003,7 +1003,7 @@ def str_encode(arr, encoding, errors="strict"):
1003
1003
1004
1004
def _noarg_wrapper (f , docstring = None , ** kargs ):
1005
1005
def wrapper (self ):
1006
- result = _na_map (f , self .series , ** kargs )
1006
+ result = _na_map (f , self ._data , ** kargs )
1007
1007
return self ._wrap_result (result )
1008
1008
1009
1009
wrapper .__name__ = f .__name__
@@ -1017,15 +1017,15 @@ def wrapper(self):
1017
1017
1018
1018
def _pat_wrapper (f , flags = False , na = False , ** kwargs ):
1019
1019
def wrapper1 (self , pat ):
1020
- result = f (self .series , pat )
1020
+ result = f (self ._data , pat )
1021
1021
return self ._wrap_result (result )
1022
1022
1023
1023
def wrapper2 (self , pat , flags = 0 , ** kwargs ):
1024
- result = f (self .series , pat , flags = flags , ** kwargs )
1024
+ result = f (self ._data , pat , flags = flags , ** kwargs )
1025
1025
return self ._wrap_result (result )
1026
1026
1027
1027
def wrapper3 (self , pat , na = np .nan ):
1028
- result = f (self .series , pat , na = na )
1028
+ result = f (self ._data , pat , na = na )
1029
1029
return self ._wrap_result (result )
1030
1030
1031
1031
wrapper = wrapper3 if na else wrapper2 if flags else wrapper1
@@ -1059,8 +1059,11 @@ class StringMethods(NoNewAttributesMixin):
1059
1059
>>> s.str.replace('_', '')
1060
1060
"""
1061
1061
1062
- def __init__ (self , series ):
1063
- self .series = series
1062
+ def __init__ (self , data ):
1063
+ self ._is_categorical = is_categorical_dtype (data )
1064
+ self ._data = data .cat .categories if self ._is_categorical else data
1065
+ # save orig to blow up categoricals to the right type
1066
+ self ._orig = data
1064
1067
self ._freeze ()
1065
1068
1066
1069
def __getitem__ (self , key ):
@@ -1078,7 +1081,15 @@ def __iter__(self):
1078
1081
i += 1
1079
1082
g = self .get (i )
1080
1083
1081
- def _wrap_result (self , result , ** kwargs ):
1084
+ def _wrap_result (self , result , use_codes = True , name = None ):
1085
+
1086
+ # for category, we do the stuff on the categories, so blow it up
1087
+ # to the full series again
1088
+ # But for some operations, we have to do the stuff on the full values,
1089
+ # so make it possible to skip this step as the method already did this before
1090
+ # the transformation...
1091
+ if use_codes and self ._is_categorical :
1092
+ result = take_1d (result , self ._orig .cat .codes )
1082
1093
1083
1094
# leave as it is to keep extract and get_dummies results
1084
1095
# can be merged to _wrap_result_expand in v0.17
@@ -1088,29 +1099,34 @@ def _wrap_result(self, result, **kwargs):
1088
1099
1089
1100
if not hasattr (result , 'ndim' ):
1090
1101
return result
1091
- name = kwargs . get ( ' name' ) or getattr (result , 'name' , None ) or self .series .name
1102
+ name = name or getattr (result , 'name' , None ) or self ._orig .name
1092
1103
1093
1104
if result .ndim == 1 :
1094
- if isinstance (self .series , Index ):
1105
+ if isinstance (self ._orig , Index ):
1095
1106
# if result is a boolean np.array, return the np.array
1096
1107
# instead of wrapping it into a boolean Index (GH 8875)
1097
1108
if is_bool_dtype (result ):
1098
1109
return result
1099
1110
return Index (result , name = name )
1100
- return Series (result , index = self .series .index , name = name )
1111
+ return Series (result , index = self ._orig .index , name = name )
1101
1112
else :
1102
1113
assert result .ndim < 3
1103
- return DataFrame (result , index = self .series .index )
1114
+ return DataFrame (result , index = self ._orig .index )
1104
1115
1105
1116
def _wrap_result_expand (self , result , expand = False ):
1106
1117
if not isinstance (expand , bool ):
1107
1118
raise ValueError ("expand must be True or False" )
1108
1119
1120
+ # for category, we do the stuff on the categories, so blow it up
1121
+ # to the full series again
1122
+ if self ._is_categorical :
1123
+ result = take_1d (result , self ._orig .cat .codes )
1124
+
1109
1125
from pandas .core .index import Index , MultiIndex
1110
1126
if not hasattr (result , 'ndim' ):
1111
1127
return result
1112
1128
1113
- if isinstance (self .series , Index ):
1129
+ if isinstance (self ._orig , Index ):
1114
1130
name = getattr (result , 'name' , None )
1115
1131
# if result is a boolean np.array, return the np.array
1116
1132
# instead of wrapping it into a boolean Index (GH 8875)
@@ -1123,36 +1139,38 @@ def _wrap_result_expand(self, result, expand=False):
1123
1139
else :
1124
1140
return Index (result , name = name )
1125
1141
else :
1126
- index = self .series .index
1142
+ index = self ._orig .index
1127
1143
if expand :
1128
1144
def cons_row (x ):
1129
1145
if is_list_like (x ):
1130
1146
return x
1131
1147
else :
1132
1148
return [ x ]
1133
- cons = self .series ._constructor_expanddim
1149
+ cons = self ._orig ._constructor_expanddim
1134
1150
data = [cons_row (x ) for x in result ]
1135
1151
return cons (data , index = index )
1136
1152
else :
1137
1153
name = getattr (result , 'name' , None )
1138
- cons = self .series ._constructor
1154
+ cons = self ._orig ._constructor
1139
1155
return cons (result , name = name , index = index )
1140
1156
1141
1157
@copy (str_cat )
1142
1158
def cat (self , others = None , sep = None , na_rep = None ):
1143
- result = str_cat (self .series , others = others , sep = sep , na_rep = na_rep )
1144
- return self ._wrap_result (result )
1159
+ data = self ._orig if self ._is_categorical else self ._data
1160
+ result = str_cat (data , others = others , sep = sep , na_rep = na_rep )
1161
+ return self ._wrap_result (result , use_codes = (not self ._is_categorical ))
1162
+
1145
1163
1146
1164
@deprecate_kwarg ('return_type' , 'expand' ,
1147
1165
mapping = {'series' : False , 'frame' : True })
1148
1166
@copy (str_split )
1149
1167
def split (self , pat = None , n = - 1 , expand = False ):
1150
- result = str_split (self .series , pat , n = n )
1168
+ result = str_split (self ._data , pat , n = n )
1151
1169
return self ._wrap_result_expand (result , expand = expand )
1152
1170
1153
1171
@copy (str_rsplit )
1154
1172
def rsplit (self , pat = None , n = - 1 , expand = False ):
1155
- result = str_rsplit (self .series , pat , n = n )
1173
+ result = str_rsplit (self ._data , pat , n = n )
1156
1174
return self ._wrap_result_expand (result , expand = expand )
1157
1175
1158
1176
_shared_docs ['str_partition' ] = ("""
@@ -1203,53 +1221,53 @@ def rsplit(self, pat=None, n=-1, expand=False):
1203
1221
'also' : 'rpartition : Split the string at the last occurrence of `sep`' })
1204
1222
def partition (self , pat = ' ' , expand = True ):
1205
1223
f = lambda x : x .partition (pat )
1206
- result = _na_map (f , self .series )
1224
+ result = _na_map (f , self ._data )
1207
1225
return self ._wrap_result_expand (result , expand = expand )
1208
1226
1209
1227
@Appender (_shared_docs ['str_partition' ] % {'side' : 'last' ,
1210
1228
'return' : '3 elements containing two empty strings, followed by the string itself' ,
1211
1229
'also' : 'partition : Split the string at the first occurrence of `sep`' })
1212
1230
def rpartition (self , pat = ' ' , expand = True ):
1213
1231
f = lambda x : x .rpartition (pat )
1214
- result = _na_map (f , self .series )
1232
+ result = _na_map (f , self ._data )
1215
1233
return self ._wrap_result_expand (result , expand = expand )
1216
1234
1217
1235
@copy (str_get )
1218
1236
def get (self , i ):
1219
- result = str_get (self .series , i )
1237
+ result = str_get (self ._data , i )
1220
1238
return self ._wrap_result (result )
1221
1239
1222
1240
@copy (str_join )
1223
1241
def join (self , sep ):
1224
- result = str_join (self .series , sep )
1242
+ result = str_join (self ._data , sep )
1225
1243
return self ._wrap_result (result )
1226
1244
1227
1245
@copy (str_contains )
1228
1246
def contains (self , pat , case = True , flags = 0 , na = np .nan , regex = True ):
1229
- result = str_contains (self .series , pat , case = case , flags = flags ,
1247
+ result = str_contains (self ._data , pat , case = case , flags = flags ,
1230
1248
na = na , regex = regex )
1231
1249
return self ._wrap_result (result )
1232
1250
1233
1251
@copy (str_match )
1234
1252
def match (self , pat , case = True , flags = 0 , na = np .nan , as_indexer = False ):
1235
- result = str_match (self .series , pat , case = case , flags = flags ,
1253
+ result = str_match (self ._data , pat , case = case , flags = flags ,
1236
1254
na = na , as_indexer = as_indexer )
1237
1255
return self ._wrap_result (result )
1238
1256
1239
1257
@copy (str_replace )
1240
1258
def replace (self , pat , repl , n = - 1 , case = True , flags = 0 ):
1241
- result = str_replace (self .series , pat , repl , n = n , case = case ,
1259
+ result = str_replace (self ._data , pat , repl , n = n , case = case ,
1242
1260
flags = flags )
1243
1261
return self ._wrap_result (result )
1244
1262
1245
1263
@copy (str_repeat )
1246
1264
def repeat (self , repeats ):
1247
- result = str_repeat (self .series , repeats )
1265
+ result = str_repeat (self ._data , repeats )
1248
1266
return self ._wrap_result (result )
1249
1267
1250
1268
@copy (str_pad )
1251
1269
def pad (self , width , side = 'left' , fillchar = ' ' ):
1252
- result = str_pad (self .series , width , side = side , fillchar = fillchar )
1270
+ result = str_pad (self ._data , width , side = side , fillchar = fillchar )
1253
1271
return self ._wrap_result (result )
1254
1272
1255
1273
_shared_docs ['str_pad' ] = ("""
@@ -1297,27 +1315,27 @@ def zfill(self, width):
1297
1315
-------
1298
1316
filled : Series/Index of objects
1299
1317
"""
1300
- result = str_pad (self .series , width , side = 'left' , fillchar = '0' )
1318
+ result = str_pad (self ._data , width , side = 'left' , fillchar = '0' )
1301
1319
return self ._wrap_result (result )
1302
1320
1303
1321
@copy (str_slice )
1304
1322
def slice (self , start = None , stop = None , step = None ):
1305
- result = str_slice (self .series , start , stop , step )
1323
+ result = str_slice (self ._data , start , stop , step )
1306
1324
return self ._wrap_result (result )
1307
1325
1308
1326
@copy (str_slice_replace )
1309
1327
def slice_replace (self , start = None , stop = None , repl = None ):
1310
- result = str_slice_replace (self .series , start , stop , repl )
1328
+ result = str_slice_replace (self ._data , start , stop , repl )
1311
1329
return self ._wrap_result (result )
1312
1330
1313
1331
@copy (str_decode )
1314
1332
def decode (self , encoding , errors = "strict" ):
1315
- result = str_decode (self .series , encoding , errors )
1333
+ result = str_decode (self ._data , encoding , errors )
1316
1334
return self ._wrap_result (result )
1317
1335
1318
1336
@copy (str_encode )
1319
1337
def encode (self , encoding , errors = "strict" ):
1320
- result = str_encode (self .series , encoding , errors )
1338
+ result = str_encode (self ._data , encoding , errors )
1321
1339
return self ._wrap_result (result )
1322
1340
1323
1341
_shared_docs ['str_strip' ] = ("""
@@ -1332,34 +1350,37 @@ def encode(self, encoding, errors="strict"):
1332
1350
@Appender (_shared_docs ['str_strip' ] % dict (side = 'left and right sides' ,
1333
1351
method = 'strip' ))
1334
1352
def strip (self , to_strip = None ):
1335
- result = str_strip (self .series , to_strip , side = 'both' )
1353
+ result = str_strip (self ._data , to_strip , side = 'both' )
1336
1354
return self ._wrap_result (result )
1337
1355
1338
1356
@Appender (_shared_docs ['str_strip' ] % dict (side = 'left side' ,
1339
1357
method = 'lstrip' ))
1340
1358
def lstrip (self , to_strip = None ):
1341
- result = str_strip (self .series , to_strip , side = 'left' )
1359
+ result = str_strip (self ._data , to_strip , side = 'left' )
1342
1360
return self ._wrap_result (result )
1343
1361
1344
1362
@Appender (_shared_docs ['str_strip' ] % dict (side = 'right side' ,
1345
1363
method = 'rstrip' ))
1346
1364
def rstrip (self , to_strip = None ):
1347
- result = str_strip (self .series , to_strip , side = 'right' )
1365
+ result = str_strip (self ._data , to_strip , side = 'right' )
1348
1366
return self ._wrap_result (result )
1349
1367
1350
1368
@copy (str_wrap )
1351
1369
def wrap (self , width , ** kwargs ):
1352
- result = str_wrap (self .series , width , ** kwargs )
1370
+ result = str_wrap (self ._data , width , ** kwargs )
1353
1371
return self ._wrap_result (result )
1354
1372
1355
1373
@copy (str_get_dummies )
1356
1374
def get_dummies (self , sep = '|' ):
1357
- result = str_get_dummies (self .series , sep )
1358
- return self ._wrap_result (result )
1375
+ # we need to cast to Series of strings as only that has all
1376
+ # methods available for making the dummies...
1377
+ data = self ._orig .astype (str ) if self ._is_categorical else self ._data
1378
+ result = str_get_dummies (data , sep )
1379
+ return self ._wrap_result (result , use_codes = (not self ._is_categorical ))
1359
1380
1360
1381
@copy (str_translate )
1361
1382
def translate (self , table , deletechars = None ):
1362
- result = str_translate (self .series , table , deletechars )
1383
+ result = str_translate (self ._data , table , deletechars )
1363
1384
return self ._wrap_result (result )
1364
1385
1365
1386
count = _pat_wrapper (str_count , flags = True )
@@ -1369,7 +1390,7 @@ def translate(self, table, deletechars=None):
1369
1390
1370
1391
@copy (str_extract )
1371
1392
def extract (self , pat , flags = 0 ):
1372
- result , name = str_extract (self .series , pat , flags = flags )
1393
+ result , name = str_extract (self ._data , pat , flags = flags )
1373
1394
return self ._wrap_result (result , name = name )
1374
1395
1375
1396
_shared_docs ['find' ] = ("""
@@ -1398,13 +1419,13 @@ def extract(self, pat, flags=0):
1398
1419
@Appender (_shared_docs ['find' ] % dict (side = 'lowest' , method = 'find' ,
1399
1420
also = 'rfind : Return highest indexes in each strings' ))
1400
1421
def find (self , sub , start = 0 , end = None ):
1401
- result = str_find (self .series , sub , start = start , end = end , side = 'left' )
1422
+ result = str_find (self ._data , sub , start = start , end = end , side = 'left' )
1402
1423
return self ._wrap_result (result )
1403
1424
1404
1425
@Appender (_shared_docs ['find' ] % dict (side = 'highest' , method = 'rfind' ,
1405
1426
also = 'find : Return lowest indexes in each strings' ))
1406
1427
def rfind (self , sub , start = 0 , end = None ):
1407
- result = str_find (self .series , sub , start = start , end = end , side = 'right' )
1428
+ result = str_find (self ._data , sub , start = start , end = end , side = 'right' )
1408
1429
return self ._wrap_result (result )
1409
1430
1410
1431
def normalize (self , form ):
@@ -1423,7 +1444,7 @@ def normalize(self, form):
1423
1444
"""
1424
1445
import unicodedata
1425
1446
f = lambda x : unicodedata .normalize (form , compat .u_safe (x ))
1426
- result = _na_map (f , self .series )
1447
+ result = _na_map (f , self ._data )
1427
1448
return self ._wrap_result (result )
1428
1449
1429
1450
_shared_docs ['index' ] = ("""
@@ -1453,13 +1474,13 @@ def normalize(self, form):
1453
1474
@Appender (_shared_docs ['index' ] % dict (side = 'lowest' , similar = 'find' , method = 'index' ,
1454
1475
also = 'rindex : Return highest indexes in each strings' ))
1455
1476
def index (self , sub , start = 0 , end = None ):
1456
- result = str_index (self .series , sub , start = start , end = end , side = 'left' )
1477
+ result = str_index (self ._data , sub , start = start , end = end , side = 'left' )
1457
1478
return self ._wrap_result (result )
1458
1479
1459
1480
@Appender (_shared_docs ['index' ] % dict (side = 'highest' , similar = 'rfind' , method = 'rindex' ,
1460
1481
also = 'index : Return lowest indexes in each strings' ))
1461
1482
def rindex (self , sub , start = 0 , end = None ):
1462
- result = str_index (self .series , sub , start = start , end = end , side = 'right' )
1483
+ result = str_index (self ._data , sub , start = start , end = end , side = 'right' )
1463
1484
return self ._wrap_result (result )
1464
1485
1465
1486
_shared_docs ['len' ] = ("""
@@ -1553,9 +1574,14 @@ class StringAccessorMixin(object):
1553
1574
def _make_str_accessor (self ):
1554
1575
from pandas .core .series import Series
1555
1576
from pandas .core .index import Index
1556
- if isinstance (self , Series ) and not is_object_dtype (self .dtype ):
1557
- # this really should exclude all series with any non-string values,
1558
- # but that isn't practical for performance reasons until we have a
1577
+ if isinstance (self , Series ) and not (
1578
+ (is_categorical_dtype (self .dtype ) and
1579
+ is_object_dtype (self .values .categories )) or
1580
+ (is_object_dtype (self .dtype ))):
1581
+ # it's neither a string series not a categorical series with strings
1582
+ # inside the categories.
1583
+ # this really should exclude all series with any non-string values (instead of test
1584
+ # for object dtype), but that isn't practical for performance reasons until we have a
1559
1585
# str dtype (GH 9343)
1560
1586
raise AttributeError ("Can only use .str accessor with string "
1561
1587
"values, which use np.object_ dtype in "
0 commit comments