Skip to content

Commit a0545a8

Browse files
committed
simplify str_extract(), pass name into _wrap_result()
1 parent fee0a7b commit a0545a8

File tree

1 file changed, with 19 additions and 16 deletions.

1 file changed, with 19 additions and 16 deletions.

pandas/core/strings.py

+19-16
Original file line numberDiff line numberDiff line change
@@ -484,17 +484,12 @@ def f(x):
484484
return empty_row
485485

486486
if regex.groups == 1:
487-
if isinstance(arr, Index):
488-
result = Index([f(val)[0] for val in arr],
489-
name=_get_single_group_name(regex),
490-
dtype=object)
491-
else:
492-
result = Series([f(val)[0] for val in arr],
493-
name=_get_single_group_name(regex),
494-
index=arr.index, dtype=object)
487+
result = np.array([f(val)[0] for val in arr], dtype=object)
488+
name = _get_single_group_name(regex)
495489
else:
496490
if isinstance(arr, Index):
497491
raise ValueError("only one regex group is supported with Index")
492+
name = None
498493
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
499494
columns = [names.get(1 + i, i) for i in range(regex.groups)]
500495
if arr.empty:
@@ -504,7 +499,7 @@ def f(x):
504499
columns=columns,
505500
index=arr.index,
506501
dtype=object)
507-
return result
502+
return result, name
508503

509504

510505
def str_get_dummies(arr, sep='|'):
@@ -1005,7 +1000,7 @@ def __iter__(self):
10051000
i += 1
10061001
g = self.get(i)
10071002

1008-
def _wrap_result(self, result):
1003+
def _wrap_result(self, result, **kwargs):
10091004
# leave as it is to keep extract and get_dummies results
10101005
# can be merged to _wrap_result_expand in v0.17
10111006
from pandas.core.series import Series
@@ -1014,16 +1009,20 @@ def _wrap_result(self, result):
10141009

10151010
if not hasattr(result, 'ndim'):
10161011
return result
1017-
elif result.ndim == 1:
1018-
name = getattr(result, 'name', None)
1012+
1013+
if 'name' in kwargs:
1014+
name = kwargs['name']
1015+
else:
1016+
name = getattr(result, 'name', None) or self.series.name
1017+
1018+
if result.ndim == 1:
10191019
if isinstance(self.series, Index):
10201020
# if result is a boolean np.array, return the np.array
10211021
# instead of wrapping it into a boolean Index (GH 8875)
10221022
if is_bool_dtype(result):
10231023
return result
1024-
return Index(result, name=name or self.series.name)
1025-
return Series(result, index=self.series.index,
1026-
name=name or self.series.name)
1024+
return Index(result, name=name)
1025+
return Series(result, index=self.series.index, name=name)
10271026
else:
10281027
assert result.ndim < 3
10291028
return DataFrame(result, index=self.series.index)
@@ -1271,7 +1270,11 @@ def get_dummies(self, sep='|'):
12711270
startswith = _pat_wrapper(str_startswith, na=True)
12721271
endswith = _pat_wrapper(str_endswith, na=True)
12731272
findall = _pat_wrapper(str_findall, flags=True)
1274-
extract = _pat_wrapper(str_extract, flags=True)
1273+
1274+
@copy(str_extract)
1275+
def extract(self, pat, flags=0):
1276+
result, name = str_extract(self.series, pat, flags=flags)
1277+
return self._wrap_result(result, name=name)
12751278

12761279
_shared_docs['find'] = ("""
12771280
Return %(side)s indexes in each strings in the Series/Index

0 commit comments

Comments
 (0)