Skip to content

Commit 3ab35b4

Browse files
sinhrksjreback
authored andcommitted
CLN: cleanup strings._wrap_result
- Merged ``strings._wrap_result`` and ``strings._wrap_result_expand`` for cleanup. Author: sinhrks <[email protected]> Closes #12487 from sinhrks/str_expand_cln and squashes the following commits: 6969c95 [sinhrks] CLN: cleanup _wrap_result
1 parent 1343011 commit 3ab35b4

File tree

1 file changed

+39
-49
lines changed

1 file changed

+39
-49
lines changed

pandas/core/strings.py

+39-49
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ def str_extract(arr, pat, flags=0, expand=None):
604604
return _str_extract_frame(arr._orig, pat, flags=flags)
605605
else:
606606
result, name = _str_extract_noexpand(arr._data, pat, flags=flags)
607-
return arr._wrap_result(result, name=name)
607+
return arr._wrap_result(result, name=name, expand=expand)
608608

609609

610610
def str_extractall(arr, pat, flags=0):
@@ -1292,7 +1292,10 @@ def __iter__(self):
12921292
i += 1
12931293
g = self.get(i)
12941294

1295-
def _wrap_result(self, result, use_codes=True, name=None):
1295+
def _wrap_result(self, result, use_codes=True,
1296+
name=None, expand=None):
1297+
1298+
from pandas.core.index import Index, MultiIndex
12961299

12971300
# for category, we do the stuff on the categories, so blow it up
12981301
# to the full series again
@@ -1302,48 +1305,42 @@ def _wrap_result(self, result, use_codes=True, name=None):
13021305
if use_codes and self._is_categorical:
13031306
result = take_1d(result, self._orig.cat.codes)
13041307

1305-
# leave as it is to keep extract and get_dummies results
1306-
# can be merged to _wrap_result_expand in v0.17
1307-
from pandas.core.series import Series
1308-
from pandas.core.frame import DataFrame
1309-
from pandas.core.index import Index
1310-
1311-
if not hasattr(result, 'ndim'):
1308+
if not hasattr(result, 'ndim') or not hasattr(result, 'dtype'):
13121309
return result
1310+
assert result.ndim < 3
13131311

1314-
if result.ndim == 1:
1315-
# Wait until we are sure result is a Series or Index before
1316-
# checking attributes (GH 12180)
1317-
name = name or getattr(result, 'name', None) or self._orig.name
1318-
if isinstance(self._orig, Index):
1319-
# if result is a boolean np.array, return the np.array
1320-
# instead of wrapping it into a boolean Index (GH 8875)
1321-
if is_bool_dtype(result):
1322-
return result
1323-
return Index(result, name=name)
1324-
return Series(result, index=self._orig.index, name=name)
1325-
else:
1326-
assert result.ndim < 3
1327-
return DataFrame(result, index=self._orig.index)
1312+
if expand is None:
1313+
# infer from ndim if expand is not specified
1314+
expand = False if result.ndim == 1 else True
1315+
1316+
elif expand is True and not isinstance(self._orig, Index):
1317+
# required when expand=True is explicitly specified
1318+
# not needed when infered
1319+
1320+
def cons_row(x):
1321+
if is_list_like(x):
1322+
return x
1323+
else:
1324+
return [x]
1325+
1326+
result = [cons_row(x) for x in result]
13281327

1329-
def _wrap_result_expand(self, result, expand=False):
13301328
if not isinstance(expand, bool):
13311329
raise ValueError("expand must be True or False")
13321330

1333-
# for category, we do the stuff on the categories, so blow it up
1334-
# to the full series again
1335-
if self._is_categorical:
1336-
result = take_1d(result, self._orig.cat.codes)
1337-
1338-
from pandas.core.index import Index, MultiIndex
1339-
if not hasattr(result, 'ndim'):
1340-
return result
1331+
if name is None:
1332+
name = getattr(result, 'name', None)
1333+
if name is None:
1334+
# do not use logical or, _orig may be a DataFrame
1335+
# which has "name" column
1336+
name = self._orig.name
13411337

1338+
# Wait until we are sure result is a Series or Index before
1339+
# checking attributes (GH 12180)
13421340
if isinstance(self._orig, Index):
1343-
name = getattr(result, 'name', None)
13441341
# if result is a boolean np.array, return the np.array
13451342
# instead of wrapping it into a boolean Index (GH 8875)
1346-
if hasattr(result, 'dtype') and is_bool_dtype(result):
1343+
if is_bool_dtype(result):
13471344
return result
13481345

13491346
if expand:
@@ -1354,18 +1351,10 @@ def _wrap_result_expand(self, result, expand=False):
13541351
else:
13551352
index = self._orig.index
13561353
if expand:
1357-
1358-
def cons_row(x):
1359-
if is_list_like(x):
1360-
return x
1361-
else:
1362-
return [x]
1363-
13641354
cons = self._orig._constructor_expanddim
1365-
data = [cons_row(x) for x in result]
1366-
return cons(data, index=index)
1355+
return cons(result, index=index)
13671356
else:
1368-
name = getattr(result, 'name', None)
1357+
# Must a Series
13691358
cons = self._orig._constructor
13701359
return cons(result, name=name, index=index)
13711360

@@ -1380,12 +1369,12 @@ def cat(self, others=None, sep=None, na_rep=None):
13801369
@copy(str_split)
13811370
def split(self, pat=None, n=-1, expand=False):
13821371
result = str_split(self._data, pat, n=n)
1383-
return self._wrap_result_expand(result, expand=expand)
1372+
return self._wrap_result(result, expand=expand)
13841373

13851374
@copy(str_rsplit)
13861375
def rsplit(self, pat=None, n=-1, expand=False):
13871376
result = str_rsplit(self._data, pat, n=n)
1388-
return self._wrap_result_expand(result, expand=expand)
1377+
return self._wrap_result(result, expand=expand)
13891378

13901379
_shared_docs['str_partition'] = ("""
13911380
Split the string at the %(side)s occurrence of `sep`, and return 3 elements
@@ -1440,7 +1429,7 @@ def rsplit(self, pat=None, n=-1, expand=False):
14401429
def partition(self, pat=' ', expand=True):
14411430
f = lambda x: x.partition(pat)
14421431
result = _na_map(f, self._data)
1443-
return self._wrap_result_expand(result, expand=expand)
1432+
return self._wrap_result(result, expand=expand)
14441433

14451434
@Appender(_shared_docs['str_partition'] % {
14461435
'side': 'last',
@@ -1451,7 +1440,7 @@ def partition(self, pat=' ', expand=True):
14511440
def rpartition(self, pat=' ', expand=True):
14521441
f = lambda x: x.rpartition(pat)
14531442
result = _na_map(f, self._data)
1454-
return self._wrap_result_expand(result, expand=expand)
1443+
return self._wrap_result(result, expand=expand)
14551444

14561445
@copy(str_get)
14571446
def get(self, i):
@@ -1597,7 +1586,8 @@ def get_dummies(self, sep='|'):
15971586
# methods available for making the dummies...
15981587
data = self._orig.astype(str) if self._is_categorical else self._data
15991588
result = str_get_dummies(data, sep)
1600-
return self._wrap_result(result, use_codes=(not self._is_categorical))
1589+
return self._wrap_result(result, use_codes=(not self._is_categorical),
1590+
expand=True)
16011591

16021592
@copy(str_translate)
16031593
def translate(self, table, deletechars=None):

0 commit comments

Comments
 (0)