Skip to content

Commit 2122963

Browse files
committed
simplify str_extract(), pass name into _wrap_result()
1 parent: fee0a7b · commit: 2122963

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

pandas/core/strings.py

+15-16
Original file line number | Diff line number | Diff line change
@@ -484,17 +484,12 @@ def f(x):
484484
return empty_row
485485

486486
if regex.groups == 1:
487-
if isinstance(arr, Index):
488-
result = Index([f(val)[0] for val in arr],
489-
name=_get_single_group_name(regex),
490-
dtype=object)
491-
else:
492-
result = Series([f(val)[0] for val in arr],
493-
name=_get_single_group_name(regex),
494-
index=arr.index, dtype=object)
487+
result = np.array([f(val)[0] for val in arr], dtype=object)
488+
name = _get_single_group_name(regex)
495489
else:
496490
if isinstance(arr, Index):
497491
raise ValueError("only one regex group is supported with Index")
492+
name = None
498493
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
499494
columns = [names.get(1 + i, i) for i in range(regex.groups)]
500495
if arr.empty:
@@ -504,7 +499,7 @@ def f(x):
504499
columns=columns,
505500
index=arr.index,
506501
dtype=object)
507-
return result
502+
return result, name
508503

509504

510505
def str_get_dummies(arr, sep='|'):
@@ -1005,7 +1000,7 @@ def __iter__(self):
10051000
i += 1
10061001
g = self.get(i)
10071002

1008-
def _wrap_result(self, result):
1003+
def _wrap_result(self, result, **kwargs):
10091004
# leave as it is to keep extract and get_dummies results
10101005
# can be merged to _wrap_result_expand in v0.17
10111006
from pandas.core.series import Series
@@ -1014,16 +1009,16 @@ def _wrap_result(self, result):
10141009

10151010
if not hasattr(result, 'ndim'):
10161011
return result
1017-
elif result.ndim == 1:
1018-
name = getattr(result, 'name', None)
1012+
name = kwargs.get('name') or getattr(result, 'name', None) or self.series.name
1013+
1014+
if result.ndim == 1:
10191015
if isinstance(self.series, Index):
10201016
# if result is a boolean np.array, return the np.array
10211017
# instead of wrapping it into a boolean Index (GH 8875)
10221018
if is_bool_dtype(result):
10231019
return result
1024-
return Index(result, name=name or self.series.name)
1025-
return Series(result, index=self.series.index,
1026-
name=name or self.series.name)
1020+
return Index(result, name=name)
1021+
return Series(result, index=self.series.index, name=name)
10271022
else:
10281023
assert result.ndim < 3
10291024
return DataFrame(result, index=self.series.index)
@@ -1271,7 +1266,11 @@ def get_dummies(self, sep='|'):
12711266
startswith = _pat_wrapper(str_startswith, na=True)
12721267
endswith = _pat_wrapper(str_endswith, na=True)
12731268
findall = _pat_wrapper(str_findall, flags=True)
1274-
extract = _pat_wrapper(str_extract, flags=True)
1269+
1270+
@copy(str_extract)
1271+
def extract(self, pat, flags=0):
1272+
result, name = str_extract(self.series, pat, flags=flags)
1273+
return self._wrap_result(result, name=name)
12751274

12761275
_shared_docs['find'] = ("""
12771276
Return %(side)s indexes in each strings in the Series/Index

0 commit comments

Comments
 (0)