diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 483cf659080ea..12aa4d4b780ef 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -167,7 +167,7 @@ Performance improvements - Performance improvement in :meth:`MultiIndex.equals` for equal length indexes (:issue:`56990`) - Performance improvement in :meth:`RangeIndex.append` when appending the same index (:issue:`57252`) - Performance improvement in indexing operations for string dtypes (:issue:`56997`) -- +- :meth:`Series.str.extract` returns :class:`RangeIndex` columns instead of :class:`Index` columns when possible (:issue:`?`) .. --------------------------------------------------------------------------- .. _whatsnew_300.bug_fixes: diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index fe68107a953bb..f0da00645b46f 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -3557,7 +3557,7 @@ def _get_single_group_name(regex: re.Pattern) -> Hashable: return None -def _get_group_names(regex: re.Pattern) -> list[Hashable]: +def _get_group_names(regex: re.Pattern) -> Index: """ Get named groups from compiled regex. 
@@ -3569,10 +3569,24 @@ def _get_group_names(regex: re.Pattern) -> list[Hashable]: Returns ------- - list of column labels + Index """ + from pandas import Index + names = {v: k for k, v in regex.groupindex.items()} - return [names.get(1 + i, i) for i in range(regex.groups)] + if not names: + return Index(range(1)) + return_rangeindex = True + result = [] + for i in range(regex.groups): + name = names.get(1 + i, i) + if return_rangeindex and name != i: + return_rangeindex = False + result.append(name) + if return_rangeindex: + return Index(range(regex.groups)) + else: + return Index(result) def str_extractall(arr, pat, flags: int = 0) -> DataFrame: diff --git a/pandas/tests/strings/test_extract.py b/pandas/tests/strings/test_extract.py index 7ebcbdc7a8533..04209e14bebbb 100644 --- a/pandas/tests/strings/test_extract.py +++ b/pandas/tests/strings/test_extract.py @@ -372,7 +372,7 @@ def test_extract_dataframe_capture_groups_index(index, any_string_dtype): result = s.str.extract(r"(\d)", expand=True) expected = DataFrame(["1", "2", np.nan], index=index, dtype=any_string_dtype) - tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(result, expected, check_column_type=True) result = s.str.extract(r"(?P\D)(?P\d)?", expand=True) expected = DataFrame(