Skip to content

Commit 149feef

Browse files
committed
Make .str available for Series of type category with strings
If a series is of type category and the underlying Categorical has categories of type string, then make it possible to use the `.str` accessor on such a series. The string methods work on the categories (and are therefore fast if we have only a few categories), but return a Series with a dtype other than category (boolean, string, ...), so that the result is no different whether we use `.str` on a series of type string or of type category.
1 parent 28b7bde commit 149feef

File tree

2 files changed

+148
-50
lines changed

2 files changed

+148
-50
lines changed

pandas/core/strings.py

+76-50
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from pandas.compat import zip
44
from pandas.core.common import (isnull, _values_from_object, is_bool_dtype, is_list_like,
5-
is_categorical_dtype, is_object_dtype)
5+
is_categorical_dtype, is_object_dtype, take_1d)
66
import pandas.compat as compat
77
from pandas.core.base import AccessorProperty, NoNewAttributesMixin
88
from pandas.util.decorators import Appender, deprecate_kwarg
@@ -1003,7 +1003,7 @@ def str_encode(arr, encoding, errors="strict"):
10031003

10041004
def _noarg_wrapper(f, docstring=None, **kargs):
10051005
def wrapper(self):
1006-
result = _na_map(f, self.series, **kargs)
1006+
result = _na_map(f, self._data, **kargs)
10071007
return self._wrap_result(result)
10081008

10091009
wrapper.__name__ = f.__name__
@@ -1017,15 +1017,15 @@ def wrapper(self):
10171017

10181018
def _pat_wrapper(f, flags=False, na=False, **kwargs):
10191019
def wrapper1(self, pat):
1020-
result = f(self.series, pat)
1020+
result = f(self._data, pat)
10211021
return self._wrap_result(result)
10221022

10231023
def wrapper2(self, pat, flags=0, **kwargs):
1024-
result = f(self.series, pat, flags=flags, **kwargs)
1024+
result = f(self._data, pat, flags=flags, **kwargs)
10251025
return self._wrap_result(result)
10261026

10271027
def wrapper3(self, pat, na=np.nan):
1028-
result = f(self.series, pat, na=na)
1028+
result = f(self._data, pat, na=na)
10291029
return self._wrap_result(result)
10301030

10311031
wrapper = wrapper3 if na else wrapper2 if flags else wrapper1
@@ -1059,8 +1059,11 @@ class StringMethods(NoNewAttributesMixin):
10591059
>>> s.str.replace('_', '')
10601060
"""
10611061

1062-
def __init__(self, series):
1063-
self.series = series
1062+
def __init__(self, data):
1063+
self._is_categorical = is_categorical_dtype(data)
1064+
self._data = data.cat.categories if self._is_categorical else data
1065+
# save orig to blow up categoricals to the right type
1066+
self._orig = data
10641067
self._freeze()
10651068

10661069
def __getitem__(self, key):
@@ -1078,7 +1081,15 @@ def __iter__(self):
10781081
i += 1
10791082
g = self.get(i)
10801083

1081-
def _wrap_result(self, result, **kwargs):
1084+
def _wrap_result(self, result, use_codes=True, name=None):
1085+
1086+
# for category, we do the stuff on the categories, so blow it up
1087+
# to the full series again
1088+
# But for some operations, we have to do the stuff on the full values,
1089+
# so make it possible to skip this step as the method already did this before
1090+
# the transformation...
1091+
if use_codes and self._is_categorical:
1092+
result = take_1d(result, self._orig.cat.codes)
10821093

10831094
# leave as it is to keep extract and get_dummies results
10841095
# can be merged to _wrap_result_expand in v0.17
@@ -1088,29 +1099,34 @@ def _wrap_result(self, result, **kwargs):
10881099

10891100
if not hasattr(result, 'ndim'):
10901101
return result
1091-
name = kwargs.get('name') or getattr(result, 'name', None) or self.series.name
1102+
name = name or getattr(result, 'name', None) or self._orig.name
10921103

10931104
if result.ndim == 1:
1094-
if isinstance(self.series, Index):
1105+
if isinstance(self._orig, Index):
10951106
# if result is a boolean np.array, return the np.array
10961107
# instead of wrapping it into a boolean Index (GH 8875)
10971108
if is_bool_dtype(result):
10981109
return result
10991110
return Index(result, name=name)
1100-
return Series(result, index=self.series.index, name=name)
1111+
return Series(result, index=self._orig.index, name=name)
11011112
else:
11021113
assert result.ndim < 3
1103-
return DataFrame(result, index=self.series.index)
1114+
return DataFrame(result, index=self._orig.index)
11041115

11051116
def _wrap_result_expand(self, result, expand=False):
11061117
if not isinstance(expand, bool):
11071118
raise ValueError("expand must be True or False")
11081119

1120+
# for category, we do the stuff on the categories, so blow it up
1121+
# to the full series again
1122+
if self._is_categorical:
1123+
result = take_1d(result, self._orig.cat.codes)
1124+
11091125
from pandas.core.index import Index, MultiIndex
11101126
if not hasattr(result, 'ndim'):
11111127
return result
11121128

1113-
if isinstance(self.series, Index):
1129+
if isinstance(self._orig, Index):
11141130
name = getattr(result, 'name', None)
11151131
# if result is a boolean np.array, return the np.array
11161132
# instead of wrapping it into a boolean Index (GH 8875)
@@ -1123,36 +1139,38 @@ def _wrap_result_expand(self, result, expand=False):
11231139
else:
11241140
return Index(result, name=name)
11251141
else:
1126-
index = self.series.index
1142+
index = self._orig.index
11271143
if expand:
11281144
def cons_row(x):
11291145
if is_list_like(x):
11301146
return x
11311147
else:
11321148
return [ x ]
1133-
cons = self.series._constructor_expanddim
1149+
cons = self._orig._constructor_expanddim
11341150
data = [cons_row(x) for x in result]
11351151
return cons(data, index=index)
11361152
else:
11371153
name = getattr(result, 'name', None)
1138-
cons = self.series._constructor
1154+
cons = self._orig._constructor
11391155
return cons(result, name=name, index=index)
11401156

11411157
@copy(str_cat)
11421158
def cat(self, others=None, sep=None, na_rep=None):
1143-
result = str_cat(self.series, others=others, sep=sep, na_rep=na_rep)
1144-
return self._wrap_result(result)
1159+
data = self._orig if self._is_categorical else self._data
1160+
result = str_cat(data, others=others, sep=sep, na_rep=na_rep)
1161+
return self._wrap_result(result, use_codes=(not self._is_categorical))
1162+
11451163

11461164
@deprecate_kwarg('return_type', 'expand',
11471165
mapping={'series': False, 'frame': True})
11481166
@copy(str_split)
11491167
def split(self, pat=None, n=-1, expand=False):
1150-
result = str_split(self.series, pat, n=n)
1168+
result = str_split(self._data, pat, n=n)
11511169
return self._wrap_result_expand(result, expand=expand)
11521170

11531171
@copy(str_rsplit)
11541172
def rsplit(self, pat=None, n=-1, expand=False):
1155-
result = str_rsplit(self.series, pat, n=n)
1173+
result = str_rsplit(self._data, pat, n=n)
11561174
return self._wrap_result_expand(result, expand=expand)
11571175

11581176
_shared_docs['str_partition'] = ("""
@@ -1203,53 +1221,53 @@ def rsplit(self, pat=None, n=-1, expand=False):
12031221
'also': 'rpartition : Split the string at the last occurrence of `sep`'})
12041222
def partition(self, pat=' ', expand=True):
12051223
f = lambda x: x.partition(pat)
1206-
result = _na_map(f, self.series)
1224+
result = _na_map(f, self._data)
12071225
return self._wrap_result_expand(result, expand=expand)
12081226

12091227
@Appender(_shared_docs['str_partition'] % {'side': 'last',
12101228
'return': '3 elements containing two empty strings, followed by the string itself',
12111229
'also': 'partition : Split the string at the first occurrence of `sep`'})
12121230
def rpartition(self, pat=' ', expand=True):
12131231
f = lambda x: x.rpartition(pat)
1214-
result = _na_map(f, self.series)
1232+
result = _na_map(f, self._data)
12151233
return self._wrap_result_expand(result, expand=expand)
12161234

12171235
@copy(str_get)
12181236
def get(self, i):
1219-
result = str_get(self.series, i)
1237+
result = str_get(self._data, i)
12201238
return self._wrap_result(result)
12211239

12221240
@copy(str_join)
12231241
def join(self, sep):
1224-
result = str_join(self.series, sep)
1242+
result = str_join(self._data, sep)
12251243
return self._wrap_result(result)
12261244

12271245
@copy(str_contains)
12281246
def contains(self, pat, case=True, flags=0, na=np.nan, regex=True):
1229-
result = str_contains(self.series, pat, case=case, flags=flags,
1247+
result = str_contains(self._data, pat, case=case, flags=flags,
12301248
na=na, regex=regex)
12311249
return self._wrap_result(result)
12321250

12331251
@copy(str_match)
12341252
def match(self, pat, case=True, flags=0, na=np.nan, as_indexer=False):
1235-
result = str_match(self.series, pat, case=case, flags=flags,
1253+
result = str_match(self._data, pat, case=case, flags=flags,
12361254
na=na, as_indexer=as_indexer)
12371255
return self._wrap_result(result)
12381256

12391257
@copy(str_replace)
12401258
def replace(self, pat, repl, n=-1, case=True, flags=0):
1241-
result = str_replace(self.series, pat, repl, n=n, case=case,
1259+
result = str_replace(self._data, pat, repl, n=n, case=case,
12421260
flags=flags)
12431261
return self._wrap_result(result)
12441262

12451263
@copy(str_repeat)
12461264
def repeat(self, repeats):
1247-
result = str_repeat(self.series, repeats)
1265+
result = str_repeat(self._data, repeats)
12481266
return self._wrap_result(result)
12491267

12501268
@copy(str_pad)
12511269
def pad(self, width, side='left', fillchar=' '):
1252-
result = str_pad(self.series, width, side=side, fillchar=fillchar)
1270+
result = str_pad(self._data, width, side=side, fillchar=fillchar)
12531271
return self._wrap_result(result)
12541272

12551273
_shared_docs['str_pad'] = ("""
@@ -1297,27 +1315,27 @@ def zfill(self, width):
12971315
-------
12981316
filled : Series/Index of objects
12991317
"""
1300-
result = str_pad(self.series, width, side='left', fillchar='0')
1318+
result = str_pad(self._data, width, side='left', fillchar='0')
13011319
return self._wrap_result(result)
13021320

13031321
@copy(str_slice)
13041322
def slice(self, start=None, stop=None, step=None):
1305-
result = str_slice(self.series, start, stop, step)
1323+
result = str_slice(self._data, start, stop, step)
13061324
return self._wrap_result(result)
13071325

13081326
@copy(str_slice_replace)
13091327
def slice_replace(self, start=None, stop=None, repl=None):
1310-
result = str_slice_replace(self.series, start, stop, repl)
1328+
result = str_slice_replace(self._data, start, stop, repl)
13111329
return self._wrap_result(result)
13121330

13131331
@copy(str_decode)
13141332
def decode(self, encoding, errors="strict"):
1315-
result = str_decode(self.series, encoding, errors)
1333+
result = str_decode(self._data, encoding, errors)
13161334
return self._wrap_result(result)
13171335

13181336
@copy(str_encode)
13191337
def encode(self, encoding, errors="strict"):
1320-
result = str_encode(self.series, encoding, errors)
1338+
result = str_encode(self._data, encoding, errors)
13211339
return self._wrap_result(result)
13221340

13231341
_shared_docs['str_strip'] = ("""
@@ -1332,34 +1350,37 @@ def encode(self, encoding, errors="strict"):
13321350
@Appender(_shared_docs['str_strip'] % dict(side='left and right sides',
13331351
method='strip'))
13341352
def strip(self, to_strip=None):
1335-
result = str_strip(self.series, to_strip, side='both')
1353+
result = str_strip(self._data, to_strip, side='both')
13361354
return self._wrap_result(result)
13371355

13381356
@Appender(_shared_docs['str_strip'] % dict(side='left side',
13391357
method='lstrip'))
13401358
def lstrip(self, to_strip=None):
1341-
result = str_strip(self.series, to_strip, side='left')
1359+
result = str_strip(self._data, to_strip, side='left')
13421360
return self._wrap_result(result)
13431361

13441362
@Appender(_shared_docs['str_strip'] % dict(side='right side',
13451363
method='rstrip'))
13461364
def rstrip(self, to_strip=None):
1347-
result = str_strip(self.series, to_strip, side='right')
1365+
result = str_strip(self._data, to_strip, side='right')
13481366
return self._wrap_result(result)
13491367

13501368
@copy(str_wrap)
13511369
def wrap(self, width, **kwargs):
1352-
result = str_wrap(self.series, width, **kwargs)
1370+
result = str_wrap(self._data, width, **kwargs)
13531371
return self._wrap_result(result)
13541372

13551373
@copy(str_get_dummies)
13561374
def get_dummies(self, sep='|'):
1357-
result = str_get_dummies(self.series, sep)
1358-
return self._wrap_result(result)
1375+
# we need to cast to Series of strings as only that has all
1376+
# methods available for making the dummies...
1377+
data = self._orig.astype(str) if self._is_categorical else self._data
1378+
result = str_get_dummies(data, sep)
1379+
return self._wrap_result(result, use_codes=(not self._is_categorical))
13591380

13601381
@copy(str_translate)
13611382
def translate(self, table, deletechars=None):
1362-
result = str_translate(self.series, table, deletechars)
1383+
result = str_translate(self._data, table, deletechars)
13631384
return self._wrap_result(result)
13641385

13651386
count = _pat_wrapper(str_count, flags=True)
@@ -1369,7 +1390,7 @@ def translate(self, table, deletechars=None):
13691390

13701391
@copy(str_extract)
13711392
def extract(self, pat, flags=0):
1372-
result, name = str_extract(self.series, pat, flags=flags)
1393+
result, name = str_extract(self._data, pat, flags=flags)
13731394
return self._wrap_result(result, name=name)
13741395

13751396
_shared_docs['find'] = ("""
@@ -1398,13 +1419,13 @@ def extract(self, pat, flags=0):
13981419
@Appender(_shared_docs['find'] % dict(side='lowest', method='find',
13991420
also='rfind : Return highest indexes in each strings'))
14001421
def find(self, sub, start=0, end=None):
1401-
result = str_find(self.series, sub, start=start, end=end, side='left')
1422+
result = str_find(self._data, sub, start=start, end=end, side='left')
14021423
return self._wrap_result(result)
14031424

14041425
@Appender(_shared_docs['find'] % dict(side='highest', method='rfind',
14051426
also='find : Return lowest indexes in each strings'))
14061427
def rfind(self, sub, start=0, end=None):
1407-
result = str_find(self.series, sub, start=start, end=end, side='right')
1428+
result = str_find(self._data, sub, start=start, end=end, side='right')
14081429
return self._wrap_result(result)
14091430

14101431
def normalize(self, form):
@@ -1423,7 +1444,7 @@ def normalize(self, form):
14231444
"""
14241445
import unicodedata
14251446
f = lambda x: unicodedata.normalize(form, compat.u_safe(x))
1426-
result = _na_map(f, self.series)
1447+
result = _na_map(f, self._data)
14271448
return self._wrap_result(result)
14281449

14291450
_shared_docs['index'] = ("""
@@ -1453,13 +1474,13 @@ def normalize(self, form):
14531474
@Appender(_shared_docs['index'] % dict(side='lowest', similar='find', method='index',
14541475
also='rindex : Return highest indexes in each strings'))
14551476
def index(self, sub, start=0, end=None):
1456-
result = str_index(self.series, sub, start=start, end=end, side='left')
1477+
result = str_index(self._data, sub, start=start, end=end, side='left')
14571478
return self._wrap_result(result)
14581479

14591480
@Appender(_shared_docs['index'] % dict(side='highest', similar='rfind', method='rindex',
14601481
also='index : Return lowest indexes in each strings'))
14611482
def rindex(self, sub, start=0, end=None):
1462-
result = str_index(self.series, sub, start=start, end=end, side='right')
1483+
result = str_index(self._data, sub, start=start, end=end, side='right')
14631484
return self._wrap_result(result)
14641485

14651486
_shared_docs['len'] = ("""
@@ -1553,9 +1574,14 @@ class StringAccessorMixin(object):
15531574
def _make_str_accessor(self):
15541575
from pandas.core.series import Series
15551576
from pandas.core.index import Index
1556-
if isinstance(self, Series) and not is_object_dtype(self.dtype):
1557-
# this really should exclude all series with any non-string values,
1558-
# but that isn't practical for performance reasons until we have a
1577+
if isinstance(self, Series) and not(
1578+
(is_categorical_dtype(self.dtype) and
1579+
is_object_dtype(self.values.categories)) or
1580+
(is_object_dtype(self.dtype))):
1581+
# it's neither a string series not a categorical series with strings
1582+
# inside the categories.
1583+
# this really should exclude all series with any non-string values (instead of test
1584+
# for object dtype), but that isn't practical for performance reasons until we have a
15591585
# str dtype (GH 9343)
15601586
raise AttributeError("Can only use .str accessor with string "
15611587
"values, which use np.object_ dtype in "

0 commit comments

Comments
 (0)