Skip to content

Commit e686387

Browse files
mortadajreback
authored andcommitted
ENH: improve extract and get_dummies methods for Index.str (fix for #9980)
simplify str_extract(), pass name into _wrap_result()
1 parent be2a9f8 commit e686387

File tree

3 files changed

+63
-29
lines changed

3 files changed

+63
-29
lines changed

doc/source/whatsnew/v0.16.1.txt

+3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Enhancements
4040
Timestamp('2014-08-01 16:30') + BusinessHour()
4141

4242
- ``DataFrame.diff`` now takes an ``axis`` parameter that determines the direction of differencing (:issue:`9727`)
43+
4344
- Allow clip, clip_lower, and clip_upper to accept array-like arguments as thresholds (:issue:`6966`). These methods now have an ``axis`` parameter which determines how the Series or DataFrame will be aligned with the threshold(s).
4445

4546
- ``DataFrame.mask()`` and ``Series.mask()`` now support same keywords as ``where`` (:issue:`8801`)
@@ -216,6 +217,8 @@ enhancements are performed to make string operation easier.
216217
idx.str.startswith('a')
217218
s[s.index.str.startswith('a')]
218219

220+
- Improved ``extract`` and ``get_dummies`` methods for ``Index.str`` (:issue:`9980`)
221+
219222
.. _whatsnew_0161.api:
220223

221224
API changes

pandas/core/strings.py

+24-11
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,7 @@ def str_extract(arr, pat, flags=0):
466466
"""
467467
from pandas.core.series import Series
468468
from pandas.core.frame import DataFrame
469+
from pandas.core.index import Index
469470

470471
regex = re.compile(pat, flags=flags)
471472
# just to be safe, check this
@@ -481,11 +482,14 @@ def f(x):
481482
return [np.nan if item is None else item for item in m.groups()]
482483
else:
483484
return empty_row
485+
484486
if regex.groups == 1:
485-
result = Series([f(val)[0] for val in arr],
486-
name=_get_single_group_name(regex),
487-
index=arr.index, dtype=object)
487+
result = np.array([f(val)[0] for val in arr], dtype=object)
488+
name = _get_single_group_name(regex)
488489
else:
490+
if isinstance(arr, Index):
491+
raise ValueError("only one regex group is supported with Index")
492+
name = None
489493
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
490494
columns = [names.get(1 + i, i) for i in range(regex.groups)]
491495
if arr.empty:
@@ -495,7 +499,7 @@ def f(x):
495499
columns=columns,
496500
index=arr.index,
497501
dtype=object)
498-
return result
502+
return result, name
499503

500504

501505
def str_get_dummies(arr, sep='|'):
@@ -531,6 +535,11 @@ def str_get_dummies(arr, sep='|'):
531535
pandas.get_dummies
532536
"""
533537
from pandas.core.frame import DataFrame
538+
from pandas.core.index import Index
539+
540+
# GH9980, Index.str does not support get_dummies() as it returns a frame
541+
if isinstance(arr, Index):
542+
raise TypeError("get_dummies is not supported for string methods on Index")
534543

535544
# TODO remove this hack?
536545
arr = arr.fillna('')
@@ -991,7 +1000,7 @@ def __iter__(self):
9911000
i += 1
9921001
g = self.get(i)
9931002

994-
def _wrap_result(self, result):
1003+
def _wrap_result(self, result, **kwargs):
9951004
# leave as it is to keep extract and get_dummies results
9961005
# can be merged to _wrap_result_expand in v0.17
9971006
from pandas.core.series import Series
@@ -1000,16 +1009,16 @@ def _wrap_result(self, result):
10001009

10011010
if not hasattr(result, 'ndim'):
10021011
return result
1003-
elif result.ndim == 1:
1004-
name = getattr(result, 'name', None)
1012+
name = kwargs.get('name') or getattr(result, 'name', None) or self.series.name
1013+
1014+
if result.ndim == 1:
10051015
if isinstance(self.series, Index):
10061016
# if result is a boolean np.array, return the np.array
10071017
# instead of wrapping it into a boolean Index (GH 8875)
10081018
if is_bool_dtype(result):
10091019
return result
1010-
return Index(result, name=name or self.series.name)
1011-
return Series(result, index=self.series.index,
1012-
name=name or self.series.name)
1020+
return Index(result, name=name)
1021+
return Series(result, index=self.series.index, name=name)
10131022
else:
10141023
assert result.ndim < 3
10151024
return DataFrame(result, index=self.series.index)
@@ -1257,7 +1266,11 @@ def get_dummies(self, sep='|'):
12571266
startswith = _pat_wrapper(str_startswith, na=True)
12581267
endswith = _pat_wrapper(str_endswith, na=True)
12591268
findall = _pat_wrapper(str_findall, flags=True)
1260-
extract = _pat_wrapper(str_extract, flags=True)
1269+
1270+
@copy(str_extract)
1271+
def extract(self, pat, flags=0):
1272+
result, name = str_extract(self.series, pat, flags=flags)
1273+
return self._wrap_result(result, name=name)
12611274

12621275
_shared_docs['find'] = ("""
12631276
Return %(side)s indexes in each strings in the Series/Index

pandas/tests/test_strings.py

+36-18
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,6 @@ def test_match(self):
516516

517517
def test_extract(self):
518518
# Contains tests like those in test_match and some others.
519-
520519
values = Series(['fooBAD__barBAD', NA, 'foo'])
521520
er = [NA, NA] # empty row
522521

@@ -540,15 +539,31 @@ def test_extract(self):
540539
exp = DataFrame([[u('BAD__'), u('BAD')], er, er])
541540
tm.assert_frame_equal(result, exp)
542541

543-
# no groups
544-
s = Series(['A1', 'B2', 'C3'])
545-
f = lambda: s.str.extract('[ABC][123]')
546-
self.assertRaises(ValueError, f)
547-
548-
# only non-capturing groups
549-
f = lambda: s.str.extract('(?:[AB]).*')
550-
self.assertRaises(ValueError, f)
542+
# GH9980
543+
# Index only works with one regex group since
544+
# multi-group would expand to a frame
545+
idx = Index(['A1', 'A2', 'A3', 'A4', 'B5'])
546+
with tm.assertRaisesRegexp(ValueError, "supported"):
547+
idx.str.extract('([AB])([123])')
548+
549+
# these should work for both Series and Index
550+
for klass in [Series, Index]:
551+
# no groups
552+
s_or_idx = klass(['A1', 'B2', 'C3'])
553+
f = lambda: s_or_idx.str.extract('[ABC][123]')
554+
self.assertRaises(ValueError, f)
555+
556+
# only non-capturing groups
557+
f = lambda: s_or_idx.str.extract('(?:[AB]).*')
558+
self.assertRaises(ValueError, f)
559+
560+
# single group renames series/index properly
561+
s_or_idx = klass(['A1', 'A2'])
562+
result = s_or_idx.str.extract(r'(?P<uno>A)\d')
563+
tm.assert_equal(result.name, 'uno')
564+
tm.assert_array_equal(result, klass(['A', 'A']))
551565

566+
s = Series(['A1', 'B2', 'C3'])
552567
# one group, no matches
553568
result = s.str.extract('(_)')
554569
exp = Series([NA, NA, NA], dtype=object)
@@ -569,14 +584,16 @@ def test_extract(self):
569584
exp = DataFrame([['A', '1'], ['B', '2'], [NA, NA]])
570585
tm.assert_frame_equal(result, exp)
571586

572-
# named group/groups
573-
result = s.str.extract('(?P<letter>[AB])(?P<number>[123])')
574-
exp = DataFrame([['A', '1'], ['B', '2'], [NA, NA]], columns=['letter', 'number'])
575-
tm.assert_frame_equal(result, exp)
587+
# one named group
576588
result = s.str.extract('(?P<letter>[AB])')
577589
exp = Series(['A', 'B', NA], name='letter')
578590
tm.assert_series_equal(result, exp)
579591

592+
# two named groups
593+
result = s.str.extract('(?P<letter>[AB])(?P<number>[123])')
594+
exp = DataFrame([['A', '1'], ['B', '2'], [NA, NA]], columns=['letter', 'number'])
595+
tm.assert_frame_equal(result, exp)
596+
580597
# mix named and unnamed groups
581598
result = s.str.extract('([AB])(?P<number>[123])')
582599
exp = DataFrame([['A', '1'], ['B', '2'], [NA, NA]], columns=[0, 'number'])
@@ -602,11 +619,6 @@ def test_extract(self):
602619
exp = DataFrame([['A', '1'], ['B', '2'], ['C', NA]], columns=['letter', 'number'])
603620
tm.assert_frame_equal(result, exp)
604621

605-
# single group renames series properly
606-
s = Series(['A1', 'A2'])
607-
result = s.str.extract(r'(?P<uno>A)\d')
608-
tm.assert_equal(result.name, 'uno')
609-
610622
# GH6348
611623
# not passing index to the extractor
612624
def check_index(index):
@@ -761,6 +773,12 @@ def test_get_dummies(self):
761773
columns=list('7ab'))
762774
tm.assert_frame_equal(result, expected)
763775

776+
# GH9980
777+
# Index.str does not support get_dummies() as it returns a frame
778+
with tm.assertRaisesRegexp(TypeError, "not supported"):
779+
idx = Index(['a|b', 'a|c', 'b|c'])
780+
idx.str.get_dummies('|')
781+
764782
def test_join(self):
765783
values = Series(['a_b_c', 'c_d_e', np.nan, 'f_g_h'])
766784
result = values.str.split('_').str.join('_')

0 commit comments

Comments
 (0)