From 43dfd983a953ccce146a4ff80cd4b9aa3312be6e Mon Sep 17 00:00:00 2001
From: Simon Hawkins
Date: Tue, 11 May 2021 11:39:46 +0100
Subject: [PATCH] [ArrowStringArray] REF: extract/extractall column names

---
 pandas/core/strings/accessor.py | 37 ++++++++++++++++++++++++---------
 1 file changed, 27 insertions(+), 10 deletions(-)

diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py
index 9f8c9fa2f0515..55a12a301c6e6 100644
--- a/pandas/core/strings/accessor.py
+++ b/pandas/core/strings/accessor.py
@@ -3,8 +3,10 @@
 import re
 from typing import (
     Dict,
+    Hashable,
     List,
     Optional,
+    Pattern,
 )
 import warnings
 
@@ -3036,13 +3038,31 @@ def _result_dtype(arr):
         return object
 
 
-def _get_single_group_name(rx):
-    try:
-        return list(rx.groupindex.keys()).pop()
-    except IndexError:
+def _get_single_group_name(regex: Pattern) -> Hashable:
+    if regex.groupindex:
+        return next(iter(regex.groupindex))
+    else:
         return None
 
 
+def _get_group_names(regex: Pattern) -> List[Hashable]:
+    """
+    Get named groups from compiled regex.
+
+    Unnamed groups are numbered.
+
+    Parameters
+    ----------
+    regex : compiled regex
+
+    Returns
+    -------
+    list of column labels
+    """
+    names = {v: k for k, v in regex.groupindex.items()}
+    return [names.get(1 + i, i) for i in range(regex.groups)]
+
+
 def _str_extract_noexpand(arr, pat, flags=0):
     """
     Find groups in each string in the Series using passed regular
@@ -3069,8 +3089,7 @@ def _str_extract_noexpand(arr, pat, flags=0):
         if isinstance(arr, ABCIndex):
             raise ValueError("only one regex group is supported with Index")
         name = None
-        names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
-        columns = [names.get(1 + i, i) for i in range(regex.groups)]
+        columns = _get_group_names(regex)
         if arr.size == 0:
             # error: Incompatible types in assignment (expression has type
             # "DataFrame", variable has type "ndarray")
@@ -3101,8 +3120,7 @@ def _str_extract_frame(arr, pat, flags=0):
 
     regex = re.compile(pat, flags=flags)
     groups_or_na = _groups_or_na_fun(regex)
-    names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
-    columns = [names.get(1 + i, i) for i in range(regex.groups)]
+    columns = _get_group_names(regex)
 
     if len(arr) == 0:
         return DataFrame(columns=columns, dtype=object)
@@ -3139,8 +3157,7 @@ def str_extractall(arr, pat, flags=0):
     if isinstance(arr, ABCIndex):
         arr = arr.to_series().reset_index(drop=True)
 
-    names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
-    columns = [names.get(1 + i, i) for i in range(regex.groups)]
+    columns = _get_group_names(regex)
     match_list = []
     index_list = []
     is_mi = arr.index.nlevels > 1