From 0c5cf2abbaec76a1f54ed03144aecbab195f0f1b Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Fri, 7 May 2021 16:55:48 +0100 Subject: [PATCH 1/4] first commit --- pandas/core/arrays/categorical.py | 4 +- pandas/core/arrays/string_.py | 4 +- pandas/core/arrays/string_arrow.py | 4 +- pandas/core/strings/accessor.py | 170 ++++++--------------------- pandas/core/strings/base.py | 4 + pandas/core/strings/object_array.py | 55 ++++++++- pandas/tests/strings/test_extract.py | 99 ++++++++-------- pandas/tests/strings/test_strings.py | 10 +- 8 files changed, 155 insertions(+), 195 deletions(-) diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index a82c75f4b2557..2005da974bb49 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -2449,7 +2449,9 @@ def replace(self, to_replace, value, inplace: bool = False): # ------------------------------------------------------------------------ # String methods interface - def _str_map(self, f, na_value=np.nan, dtype=np.dtype("object")): + def _str_map( + self, f, na_value=np.nan, dtype=np.dtype("object"), convert: bool = True + ): # Optimization to apply the callable `f` to the categories once # and rebuild the result by `take`ing from the result with the codes. # Returns the same type as the object-dtype implementation though. diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 74ca5130ca322..ab1dadf4d2dfa 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -410,7 +410,9 @@ def _cmp_method(self, other, op): # String methods interface _str_na_value = StringDtype.na_value - def _str_map(self, f, na_value=None, dtype: Dtype | None = None): + def _str_map( + self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True + ): from pandas.arrays import BooleanArray if dtype is None: diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 219a8c7ec0b82..ebc7345bce527 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -741,7 +741,9 @@ def value_counts(self, dropna: bool = True) -> Series: _str_na_value = ArrowStringDtype.na_value - def _str_map(self, f, na_value=None, dtype: Dtype | None = None): + def _str_map( + self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True + ): # TODO: de-duplicate with StringArray method. This method is moreless copy and # paste. diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index f8df05a7022d1..96d3de17fb5b9 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -11,6 +11,7 @@ import numpy as np import pandas._libs.lib as lib +from pandas._typing import IndexLabel from pandas.util._decorators import Appender from pandas.core.dtypes.common import ( @@ -150,7 +151,6 @@ class StringMethods(NoNewAttributesMixin): # TODO: Dispatch all the methods # Currently the following are not dispatched to the array # * cat - # * extract # * extractall def __init__(self, data): @@ -233,7 +233,7 @@ def _wrap_result( self, result, name=None, - expand=None, + expand: Optional[bool] = None, fill_value=np.nan, returns_string=True, ): @@ -257,24 +257,6 @@ def _wrap_result( # infer from ndim if expand is not specified expand = result.ndim != 1 - elif expand is True and not isinstance(self._orig, ABCIndex): - # required when expand=True is explicitly specified - # not needed when inferred - - def cons_row(x): - if is_list_like(x): - return x - else: - return [x] - - result = [cons_row(x) for x in result] - if result: - # propagate nan values to match longest sequence (GH 18450) - max_len = max(len(x) for x in result) - result = [ - x * max_len if len(x) == 0 or x[0] is np.nan else x for x in result - ] - if not isinstance(expand, bool): raise ValueError("expand must be True or False") @@ -310,14 +292,14 @@ def cons_row(x): index = self._orig.index # This is a mess. dtype: Optional[str] - if self._is_string and returns_string: + if not self._is_categorical and returns_string: dtype = self._orig.dtype else: dtype = None if expand: cons = self._orig._constructor_expanddim - result = cons(result, columns=name, index=index, dtype=dtype) + result = cons(list(result), columns=name, index=index, dtype=dtype) else: # Must be a Series cons = self._orig._constructor @@ -2272,7 +2254,7 @@ def findall(self, pat, flags=0): return self._wrap_result(result, returns_string=False) @forbid_nonstring_types(["bytes"]) - def extract(self, pat, flags=0, expand=True): + def extract(self, pat: str, flags: int = 0, expand: bool = True): r""" Extract capture groups in the regex `pat` as columns in a DataFrame. @@ -2353,8 +2335,34 @@ def extract(self, pat, flags=0, expand=True): 2 NaN dtype: object """ - # TODO: dispatch - return str_extract(self, pat, flags, expand=expand) + if not isinstance(expand, bool): + raise ValueError("expand must be True or False") + + regex = re.compile(pat, flags=flags) + if regex.groups == 0: + raise ValueError("pattern contains no capture groups") + + if not expand and regex.groups > 1 and isinstance(self._data, ABCIndex): + raise ValueError("only one regex group is supported with Index") + + result = self._data.array._str_extract(pat, flags, expand) + + returns_df = regex.groups > 1 or expand + name: IndexLabel + if returns_df: + names = {v: k for k, v in regex.groupindex.items()} + name = [names.get(1 + i, i) for i in range(regex.groups)] + else: + name = next(iter(regex.groupindex)) if regex.groupindex else None + + # extract is inconsistent for Indexes when expand is True. To avoid special + # casing _wrap_result we handle that case here + if expand and isinstance(self._data, ABCIndex): + from pandas import DataFrame + + return DataFrame(list(result), columns=name, dtype=object) + + return self._wrap_result(result, name=name, expand=returns_df) @forbid_nonstring_types(["bytes"]) def extractall(self, pat, flags=0): @@ -3004,24 +3012,6 @@ def cat_core(list_of_columns: List, sep: str): return np.sum(arr_with_sep, axis=0) -def _groups_or_na_fun(regex): - """Used in both extract_noexpand and extract_frame""" - if regex.groups == 0: - raise ValueError("pattern contains no capture groups") - empty_row = [np.nan] * regex.groups - - def f(x): - if not isinstance(x, str): - return empty_row - m = regex.search(x) - if m: - return [np.nan if item is None else item for item in m.groups()] - else: - return empty_row - - return f - - def _result_dtype(arr): # workaround #27953 # ideally we just pass `dtype=arr.dtype` unconditionally, but this fails @@ -3034,100 +3024,6 @@ def _result_dtype(arr): return object -def _get_single_group_name(rx): - try: - return list(rx.groupindex.keys()).pop() - except IndexError: - return None - - -def _str_extract_noexpand(arr, pat, flags=0): - """ - Find groups in each string in the Series using passed regular - expression. This function is called from - str_extract(expand=False), and can return Series, DataFrame, or - Index. - - """ - from pandas import ( - DataFrame, - array as pd_array, - ) - - regex = re.compile(pat, flags=flags) - groups_or_na = _groups_or_na_fun(regex) - result_dtype = _result_dtype(arr) - - if regex.groups == 1: - result = np.array([groups_or_na(val)[0] for val in arr], dtype=object) - name = _get_single_group_name(regex) - # not dispatching, so we have to reconstruct here. - result = pd_array(result, dtype=result_dtype) - else: - if isinstance(arr, ABCIndex): - raise ValueError("only one regex group is supported with Index") - name = None - names = dict(zip(regex.groupindex.values(), regex.groupindex.keys())) - columns = [names.get(1 + i, i) for i in range(regex.groups)] - if arr.size == 0: - # error: Incompatible types in assignment (expression has type - # "DataFrame", variable has type "ndarray") - result = DataFrame( # type: ignore[assignment] - columns=columns, dtype=object - ) - else: - dtype = _result_dtype(arr) - # error: Incompatible types in assignment (expression has type - # "DataFrame", variable has type "ndarray") - result = DataFrame( # type:ignore[assignment] - [groups_or_na(val) for val in arr], - columns=columns, - index=arr.index, - dtype=dtype, - ) - return result, name - - -def _str_extract_frame(arr, pat, flags=0): - """ - For each subject string in the Series, extract groups from the - first match of regular expression pat. This function is called from - str_extract(expand=True), and always returns a DataFrame. - - """ - from pandas import DataFrame - - regex = re.compile(pat, flags=flags) - groups_or_na = _groups_or_na_fun(regex) - names = dict(zip(regex.groupindex.values(), regex.groupindex.keys())) - columns = [names.get(1 + i, i) for i in range(regex.groups)] - - if len(arr) == 0: - return DataFrame(columns=columns, dtype=object) - try: - result_index = arr.index - except AttributeError: - result_index = None - dtype = _result_dtype(arr) - return DataFrame( - [groups_or_na(val) for val in arr], - columns=columns, - index=result_index, - dtype=dtype, - ) - - -def str_extract(arr, pat, flags=0, expand=True): - if not isinstance(expand, bool): - raise ValueError("expand must be True or False") - if expand: - result = _str_extract_frame(arr._orig, pat, flags=flags) - return result.__finalize__(arr._orig, method="str_extract") - else: - result, name = _str_extract_noexpand(arr._orig, pat, flags=flags) - return arr._wrap_result(result, name=name, expand=expand) - - def str_extractall(arr, pat, flags=0): regex = re.compile(pat, flags=flags) # the regex must contain capture groups. diff --git a/pandas/core/strings/base.py b/pandas/core/strings/base.py index a77f8861a7c02..9cd0f25c1055a 100644 --- a/pandas/core/strings/base.py +++ b/pandas/core/strings/base.py @@ -222,3 +222,7 @@ def _str_split(self, pat=None, n=-1, expand=False): @abc.abstractmethod def _str_rsplit(self, pat=None, n=-1): pass + + @abc.abstractmethod + def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): + pass diff --git a/pandas/core/strings/object_array.py b/pandas/core/strings/object_array.py index 869eabc76b555..2799dbdbaae22 100644 --- a/pandas/core/strings/object_array.py +++ b/pandas/core/strings/object_array.py @@ -38,7 +38,9 @@ def __len__(self): # For typing, _str_map relies on the object being sized. raise NotImplementedError - def _str_map(self, f, na_value=None, dtype: Optional[Dtype] = None): + def _str_map( + self, f, na_value=None, dtype: Optional[Dtype] = None, convert: bool = True + ): """ Map a callable over valid element of the array. @@ -53,6 +55,8 @@ def _str_map(self, f, na_value=None, dtype: Optional[Dtype] = None): for object-dtype and Categorical and ``pd.NA`` for StringArray. dtype : Dtype, optional The dtype of the result array. + convert : bool, default True + Whether to call `maybe_convert_objects` on the resulting ndarray """ if dtype is None: dtype = np.dtype("object") @@ -66,9 +70,9 @@ def _str_map(self, f, na_value=None, dtype: Optional[Dtype] = None): arr = np.asarray(self, dtype=object) mask = isna(arr) - convert = not np.all(mask) + map_convert = convert and not np.all(mask) try: - result = lib.map_infer_mask(arr, f, mask.view(np.uint8), convert) + result = lib.map_infer_mask(arr, f, mask.view(np.uint8), map_convert) except (TypeError, AttributeError) as e: # Reraise the exception if callable `f` got wrong number of args. # The user may want to be warned by this, instead of getting NaN @@ -94,7 +98,7 @@ def g(x): return result if na_value is not np.nan: np.putmask(result, mask, na_value) - if result.dtype == object: + if convert and result.dtype == object: result = lib.maybe_convert_objects(result) return result @@ -314,7 +318,21 @@ def _str_split(self, pat=None, n=-1, expand=False): n = 0 regex = re.compile(pat) f = lambda x: regex.split(x, maxsplit=n) - return self._str_map(f, dtype=object) + + result = self._str_map(f, dtype=object) + + # propagate nan values to match longest sequence (GH 18450) + if expand: + mask = isna(result) + valid = result[~mask] + if len(valid): + max_len = max(len(x) for x in valid) + na_value = self._str_na_value + empty_row = [na_value] * max_len + for idx in np.argwhere(mask): + result[idx[0]] = empty_row + + return result def _str_rsplit(self, pat=None, n=-1): if n is None or n == 0: @@ -408,3 +426,30 @@ def _str_lstrip(self, to_strip=None): def _str_rstrip(self, to_strip=None): return self._str_map(lambda x: x.rstrip(to_strip)) + + def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): + regex = re.compile(pat, flags=flags) + na_value = self._str_na_value + + if regex.groups == 1: + + def mapper(x): + m = regex.search(x) + return m.groups()[0] if m else na_value + + return self._str_map(mapper, convert=False) + else: + empty_row = [self._str_na_value] * regex.groups + + def mapper(x): + m = regex.search(x) + return ( + [na_value if item is None else item for item in m.groups()] + if m + else empty_row + ) + + result = self._str_map(mapper, dtype="object") + for idx in np.argwhere(isna(result)): + result[idx[0]] = empty_row + return result diff --git a/pandas/tests/strings/test_extract.py b/pandas/tests/strings/test_extract.py index c1564a5c256a1..b8e431a33c642 100644 --- a/pandas/tests/strings/test_extract.py +++ b/pandas/tests/strings/test_extract.py @@ -13,30 +13,31 @@ ) -def test_extract_expand_None(): - values = Series(["fooBAD__barBAD", np.nan, "foo"]) +def test_extract_expand_kwarg_wrong_type_raises(any_string_dtype): + # TODO: should this raise TypeError + values = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype) with pytest.raises(ValueError, match="expand must be True or False"): values.str.extract(".*(BAD[_]+).*(BAD)", expand=None) -def test_extract_expand_unspecified(): - values = Series(["fooBAD__barBAD", np.nan, "foo"]) - result_unspecified = values.str.extract(".*(BAD[_]+).*") - assert isinstance(result_unspecified, DataFrame) - result_true = values.str.extract(".*(BAD[_]+).*", expand=True) - tm.assert_frame_equal(result_unspecified, result_true) +def test_extract_expand_kwarg(any_string_dtype): + s = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype) + expected = DataFrame(["BAD__", np.nan, np.nan], dtype=any_string_dtype) + result = s.str.extract(".*(BAD[_]+).*") + tm.assert_frame_equal(result, expected) -def test_extract_expand_False(): - # Contains tests like those in test_match and some others. - values = Series(["fooBAD__barBAD", np.nan, "foo"]) - er = [np.nan, np.nan] # empty row + result = s.str.extract(".*(BAD[_]+).*", expand=True) + tm.assert_frame_equal(result, expected) - result = values.str.extract(".*(BAD[_]+).*(BAD)", expand=False) - exp = DataFrame([["BAD__", "BAD"], er, er]) - tm.assert_frame_equal(result, exp) + expected = DataFrame( + [["BAD__", "BAD"], [np.nan, np.nan], [np.nan, np.nan]], dtype=any_string_dtype + ) + result = s.str.extract(".*(BAD[_]+).*(BAD)", expand=False) + tm.assert_frame_equal(result, expected) - # mixed + +def test_extract_expand_mixed_object(): mixed = Series( [ "aBAD_BAD", @@ -51,47 +52,51 @@ def test_extract_expand_False(): ] ) - rs = Series(mixed).str.extract(".*(BAD[_]+).*(BAD)", expand=False) - exp = DataFrame([["BAD_", "BAD"], er, ["BAD_", "BAD"], er, er, er, er, er, er]) - tm.assert_frame_equal(rs, exp) - - # unicode - values = Series(["fooBAD__barBAD", np.nan, "foo"]) + result = Series(mixed).str.extract(".*(BAD[_]+).*(BAD)", expand=False) + er = [np.nan, np.nan] # empty row + expected = DataFrame([["BAD_", "BAD"], er, ["BAD_", "BAD"], er, er, er, er, er, er]) + tm.assert_frame_equal(result, expected) - result = values.str.extract(".*(BAD[_]+).*(BAD)", expand=False) - exp = DataFrame([["BAD__", "BAD"], er, er]) - tm.assert_frame_equal(result, exp) +def test_extract_expand_index_raises(): # GH9980 # Index only works with one regex group since # multi-group would expand to a frame idx = Index(["A1", "A2", "A3", "A4", "B5"]) - with pytest.raises(ValueError, match="supported"): + msg = "only one regex group is supported with Index" + with pytest.raises(ValueError, match=msg): idx.str.extract("([AB])([123])", expand=False) - # these should work for both Series and Index - for klass in [Series, Index]: - # no groups - s_or_idx = klass(["A1", "B2", "C3"]) - msg = "pattern contains no capture groups" - with pytest.raises(ValueError, match=msg): - s_or_idx.str.extract("[ABC][123]", expand=False) - # only non-capturing groups - with pytest.raises(ValueError, match=msg): - s_or_idx.str.extract("(?:[AB]).*", expand=False) +@pytest.mark.parametrize("klass", [Series, Index]) +def test_extract_expand_no_capture_groups_raises(klass, any_string_dtype): + s_or_idx = klass(["A1", "B2", "C3"], dtype=any_string_dtype) + msg = "pattern contains no capture groups" - # single group renames series/index properly - s_or_idx = klass(["A1", "A2"]) - result = s_or_idx.str.extract(r"(?PA)\d", expand=False) - assert result.name == "uno" + # no groups + with pytest.raises(ValueError, match=msg): + s_or_idx.str.extract("[ABC][123]", expand=False) + + # only non-capturing groups + with pytest.raises(ValueError, match=msg): + s_or_idx.str.extract("(?:[AB]).*", expand=False) + + +@pytest.mark.parametrize("klass", [Series, Index]) +def test_extract_expand_single_capture_group(klass, any_string_dtype): + # single group renames series/index properly + s_or_idx = klass(["A1", "A2"], dtype=any_string_dtype) + result = s_or_idx.str.extract(r"(?PA)\d", expand=False) + assert result.name == "uno" + + expected = klass(["A", "A"], dtype=any_string_dtype, name="uno") + if klass == Series: + tm.assert_series_equal(result, expected) + else: + tm.assert_index_equal(result, expected) - exp = klass(["A", "A"], name="uno") - if klass == Series: - tm.assert_series_equal(result, exp) - else: - tm.assert_index_equal(result, exp) +def test_extract_expand_capture_groups(): s = Series(["A1", "B2", "C3"]) # one group, no matches result = s.str.extract("(_)", expand=False) @@ -162,7 +167,9 @@ def test_extract_expand_False(): ) tm.assert_frame_equal(result, exp) - # GH6348 + +def test_extract_expand_capture_groups_index(): + # https://github.com/pandas-dev/pandas/issues/6348 # not passing index to the extractor def check_index(index): data = ["A1", "B2", "C"] diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py index 5d8a63fe481f8..317456fbfcd94 100644 --- a/pandas/tests/strings/test_strings.py +++ b/pandas/tests/strings/test_strings.py @@ -175,17 +175,19 @@ def test_empty_str_methods(any_string_dtype): tm.assert_series_equal(empty_str, empty.str.repeat(3)) tm.assert_series_equal(empty_bool, empty.str.match("^a")) tm.assert_frame_equal( - DataFrame(columns=[0], dtype=str), empty.str.extract("()", expand=True) + DataFrame(columns=[0], dtype=any_string_dtype), + empty.str.extract("()", expand=True), ) tm.assert_frame_equal( - DataFrame(columns=[0, 1], dtype=str), empty.str.extract("()()", expand=True) + DataFrame(columns=[0, 1], dtype=any_string_dtype), + empty.str.extract("()()", expand=True), ) tm.assert_series_equal(empty_str, empty.str.extract("()", expand=False)) tm.assert_frame_equal( - DataFrame(columns=[0, 1], dtype=str), + DataFrame(columns=[0, 1], dtype=any_string_dtype), empty.str.extract("()()", expand=False), ) - tm.assert_frame_equal(DataFrame(dtype=str), empty.str.get_dummies()) + tm.assert_frame_equal(DataFrame(dtype=any_string_dtype), empty.str.get_dummies()) tm.assert_series_equal(empty_str, empty_str.str.join("")) tm.assert_series_equal(empty_int, empty.str.len()) tm.assert_series_equal(empty_object, empty_str.str.findall("a")) From c09b664321aaf22c6384aca1151d136b9ca21591 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Wed, 12 May 2021 11:53:13 +0100 Subject: [PATCH 2/4] use _get_single_group_name and _get_group_names --- pandas/core/strings/accessor.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index fa5fd42cba893..48b7881e6bfd0 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -13,7 +13,6 @@ import numpy as np import pandas._libs.lib as lib -from pandas._typing import IndexLabel from pandas.util._decorators import Appender from pandas.core.dtypes.common import ( @@ -2352,18 +2351,15 @@ def extract(self, pat: str, flags: int = 0, expand: bool = True): result = self._data.array._str_extract(pat, flags, expand) returns_df = regex.groups > 1 or expand - name: IndexLabel - if returns_df: - names = {v: k for k, v in regex.groupindex.items()} - name = [names.get(1 + i, i) for i in range(regex.groups)] - else: - name = next(iter(regex.groupindex)) if regex.groupindex else None + name = _get_group_names(regex) if returns_df else _get_single_group_name(regex) # extract is inconsistent for Indexes when expand is True. To avoid special # casing _wrap_result we handle that case here if expand and isinstance(self._data, ABCIndex): from pandas import DataFrame + # if expand is True, name is a list of column names + assert isinstance(name, list) # for mypy return DataFrame(list(result), columns=name, dtype=object) return self._wrap_result(result, name=name, expand=returns_df) From 75f9429f61812516ae4fbe33e2708ad8b25dc11e Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Sat, 15 May 2021 21:19:38 +0100 Subject: [PATCH 3/4] return 2d object array from _str_extract --- pandas/core/strings/accessor.py | 48 ++++++++++++++++++++++++++--- pandas/core/strings/object_array.py | 18 ++--------- 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 2300d2357d0ab..c5dbcdbbbab89 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -24,6 +24,7 @@ is_categorical_dtype, is_integer, is_list_like, + is_object_dtype, is_re, ) from pandas.core.dtypes.generic import ( @@ -264,6 +265,28 @@ def _wrap_result( # infer from ndim if expand is not specified expand = result.ndim != 1 + elif ( + expand is True + and is_object_dtype(result) + and not isinstance(self._orig, ABCIndex) + ): + # required when expand=True is explicitly specified + # not needed when inferred + + def cons_row(x): + if is_list_like(x): + return x + else: + return [x] + + result = [cons_row(x) for x in result] + if result: + # propagate nan values to match longest sequence (GH 18450) + max_len = max(len(x) for x in result) + result = [ + x * max_len if len(x) == 0 or x[0] is np.nan else x for x in result + ] + if not isinstance(expand, bool): raise ValueError("expand must be True or False") @@ -299,14 +322,14 @@ def _wrap_result( index = self._orig.index # This is a mess. dtype: Optional[str] - if not self._is_categorical and returns_string: + if self._is_string and returns_string: dtype = self._orig.dtype else: dtype = None if expand: cons = self._orig._constructor_expanddim - result = cons(list(result), columns=name, index=index, dtype=dtype) + result = cons(result, columns=name, index=index, dtype=dtype) else: # Must be a Series cons = self._orig._constructor @@ -2357,8 +2380,8 @@ def extract( raise ValueError("only one regex group is supported with Index") result = self._data.array._str_extract(pat, flags, expand) - returns_df = regex.groups > 1 or expand + name = _get_group_names(regex) if returns_df else _get_single_group_name(regex) # extract is inconsistent for Indexes when expand is True. To avoid special @@ -2368,9 +2391,24 @@ def extract( # if expand is True, name is a list of column names assert isinstance(name, list) # for mypy - return DataFrame(list(result), columns=name, dtype=object) + return DataFrame(result, columns=name, dtype=object) + + # bypass padding code in _wrap_result + expand_kwarg: Optional[bool] + if returns_df: + if is_object_dtype(result): + if regex.groups == 1: + result = result.reshape(1, -1).T + if result.size == 0: + expand_kwarg = True + else: + expand_kwarg = None + else: + expand_kwarg = True + else: + expand_kwarg = False - return self._wrap_result(result, name=name, expand=returns_df) + return self._wrap_result(result, name=name, expand=expand_kwarg) @forbid_nonstring_types(["bytes"]) def extractall(self, pat, flags=0): diff --git a/pandas/core/strings/object_array.py b/pandas/core/strings/object_array.py index 2799dbdbaae22..75078155453c5 100644 --- a/pandas/core/strings/object_array.py +++ b/pandas/core/strings/object_array.py @@ -318,21 +318,7 @@ def _str_split(self, pat=None, n=-1, expand=False): n = 0 regex = re.compile(pat) f = lambda x: regex.split(x, maxsplit=n) - - result = self._str_map(f, dtype=object) - - # propagate nan values to match longest sequence (GH 18450) - if expand: - mask = isna(result) - valid = result[~mask] - if len(valid): - max_len = max(len(x) for x in valid) - na_value = self._str_na_value - empty_row = [na_value] * max_len - for idx in np.argwhere(mask): - result[idx[0]] = empty_row - - return result + return self._str_map(f, dtype=object) def _str_rsplit(self, pat=None, n=-1): if n is None or n == 0: @@ -452,4 +438,4 @@ def mapper(x): result = self._str_map(mapper, dtype="object") for idx in np.argwhere(isna(result)): result[idx[0]] = empty_row - return result + return np.array(list(result), dtype=object) From 48dd8d2302d859a3d65b148f4a4021758c94298f Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Sun, 16 May 2021 19:08:40 +0100 Subject: [PATCH 4/4] don't use _str_map for 2d --- pandas/_libs/lib.pyi | 1 + pandas/_libs/lib.pyx | 10 ++++-- pandas/core/strings/object_array.py | 52 ++++++++++++++++++++--------- 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/pandas/_libs/lib.pyi b/pandas/_libs/lib.pyi index 9dbc47f1d40f7..902d1b7245b46 100644 --- a/pandas/_libs/lib.pyi +++ b/pandas/_libs/lib.pyi @@ -196,6 +196,7 @@ def map_infer_mask( convert: bool = ..., na_value: Any = ..., dtype: np.dtype = ..., + out: np.ndarray = ..., ) -> np.ndarray: ... def indices_fast( diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index e1cb744c7033c..58bd6cdc54b7d 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -2536,7 +2536,8 @@ no_default = NoDefault.no_default # Sentinel indicating the default value. @cython.boundscheck(False) @cython.wraparound(False) def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=True, - object na_value=no_default, cnp.dtype dtype=np.dtype(object) + object na_value=no_default, cnp.dtype dtype=np.dtype(object), + ndarray out=None ) -> np.ndarray: """ Substitute for np.vectorize with pandas-friendly dtype inference. @@ -2554,6 +2555,8 @@ def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=Tr input value is used dtype : numpy.dtype The numpy dtype to use for the result ndarray. + out : ndarray + The result. Returns ------- @@ -2565,7 +2568,10 @@ def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=Tr object val n = len(arr) - result = np.empty(n, dtype=dtype) + if out is not None: + result = out + else: + result = np.empty(n, dtype=dtype) for i in range(n): if mask[i]: if na_value is no_default: diff --git a/pandas/core/strings/object_array.py b/pandas/core/strings/object_array.py index 75078155453c5..0f54d5f869973 100644 --- a/pandas/core/strings/object_array.py +++ b/pandas/core/strings/object_array.py @@ -19,6 +19,7 @@ ) from pandas.core.dtypes.common import ( + is_object_dtype, is_re, is_scalar, ) @@ -419,23 +420,44 @@ def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): if regex.groups == 1: - def mapper(x): + def f(x): m = regex.search(x) return m.groups()[0] if m else na_value - return self._str_map(mapper, convert=False) + return self._str_map(f, convert=False) else: - empty_row = [self._str_na_value] * regex.groups + out = np.empty((len(self), regex.groups), dtype=object) - def mapper(x): - m = regex.search(x) - return ( - [na_value if item is None else item for item in m.groups()] - if m - else empty_row - ) - - result = self._str_map(mapper, dtype="object") - for idx in np.argwhere(isna(result)): - result[idx[0]] = empty_row - return np.array(list(result), dtype=object) + if is_object_dtype(self): + + def f(x): + if not isinstance(x, str): + return na_value + m = regex.search(x) + if m: + return [ + na_value if item is None else item for item in m.groups() + ] + else: + return na_value + + else: + + def f(x): + m = regex.search(x) + if m: + return [ + na_value if item is None else item for item in m.groups() + ] + else: + return na_value + + result = lib.map_infer_mask( + np.asarray(self), + f, + mask=isna(self).view("uint8"), + convert=False, + na_value=na_value, + out=out, + ) + return result