From fbc8f1edab49ad36cd902ab9192d95b9e730bf0b Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Wed, 12 May 2021 18:42:43 +0100 Subject: [PATCH 1/3] BUG: Series.str.extract with StringArray returning object dtype --- doc/source/whatsnew/v1.3.0.rst | 1 + pandas/core/strings/accessor.py | 11 +++++------ pandas/tests/strings/test_strings.py | 10 ++++++---- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 84f9dae8a0850..8dba343422777 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -750,6 +750,7 @@ Strings - Bug in the conversion from ``pyarrow.ChunkedArray`` to :class:`~arrays.StringArray` when the original had zero chunks (:issue:`41040`) - Bug in :meth:`Series.replace` and :meth:`DataFrame.replace` ignoring replacements with ``regex=True`` for ``StringDType`` data (:issue:`41333`, :issue:`35977`) +- Bug in :meth:`Series.str.extract` with :class:`~arrays.StringArray` returning object dtype for empty :class:`Series` or :class:`DataFrame` (:issue:`XXXXX`) Interval ^^^^^^^^ diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 2646ddfa45b58..025ec232adcb5 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -3108,17 +3108,16 @@ def _str_extract_noexpand(arr, pat, flags=0): # error: Incompatible types in assignment (expression has type # "DataFrame", variable has type "ndarray") result = DataFrame( # type: ignore[assignment] - columns=columns, dtype=object + columns=columns, dtype=result_dtype ) else: - dtype = _result_dtype(arr) # error: Incompatible types in assignment (expression has type # "DataFrame", variable has type "ndarray") result = DataFrame( # type:ignore[assignment] [groups_or_na(val) for val in arr], columns=columns, index=arr.index, - dtype=dtype, + dtype=result_dtype, ) return result, name @@ -3135,19 +3134,19 @@ def _str_extract_frame(arr, pat, flags=0): regex = re.compile(pat, flags=flags) groups_or_na = _groups_or_na_fun(regex) columns = _get_group_names(regex) + result_dtype = _result_dtype(arr) if len(arr) == 0: - return DataFrame(columns=columns, dtype=object) + return DataFrame(columns=columns, dtype=result_dtype) try: result_index = arr.index except AttributeError: result_index = None - dtype = _result_dtype(arr) return DataFrame( [groups_or_na(val) for val in arr], columns=columns, index=result_index, - dtype=dtype, + dtype=result_dtype, ) diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py index 5d8a63fe481f8..317456fbfcd94 100644 --- a/pandas/tests/strings/test_strings.py +++ b/pandas/tests/strings/test_strings.py @@ -175,17 +175,19 @@ def test_empty_str_methods(any_string_dtype): tm.assert_series_equal(empty_str, empty.str.repeat(3)) tm.assert_series_equal(empty_bool, empty.str.match("^a")) tm.assert_frame_equal( - DataFrame(columns=[0], dtype=str), empty.str.extract("()", expand=True) + DataFrame(columns=[0], dtype=any_string_dtype), + empty.str.extract("()", expand=True), ) tm.assert_frame_equal( - DataFrame(columns=[0, 1], dtype=str), empty.str.extract("()()", expand=True) + DataFrame(columns=[0, 1], dtype=any_string_dtype), + empty.str.extract("()()", expand=True), ) tm.assert_series_equal(empty_str, empty.str.extract("()", expand=False)) tm.assert_frame_equal( - DataFrame(columns=[0, 1], dtype=str), + DataFrame(columns=[0, 1], dtype=any_string_dtype), empty.str.extract("()()", expand=False), ) - tm.assert_frame_equal(DataFrame(dtype=str), empty.str.get_dummies()) + tm.assert_frame_equal(DataFrame(dtype=any_string_dtype), empty.str.get_dummies()) tm.assert_series_equal(empty_str, empty_str.str.join("")) tm.assert_series_equal(empty_int, empty.str.len()) tm.assert_series_equal(empty_object, empty_str.str.findall("a")) From 63df8e421074f34495852038c875676642923426 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Wed, 12 May 2021 18:47:56 +0100 Subject: [PATCH 2/3] add issue number --- doc/source/whatsnew/v1.3.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 8dba343422777..6b9b9b75cf048 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -750,7 +750,7 @@ Strings - Bug in the conversion from ``pyarrow.ChunkedArray`` to :class:`~arrays.StringArray` when the original had zero chunks (:issue:`41040`) - Bug in :meth:`Series.replace` and :meth:`DataFrame.replace` ignoring replacements with ``regex=True`` for ``StringDType`` data (:issue:`41333`, :issue:`35977`) -- Bug in :meth:`Series.str.extract` with :class:`~arrays.StringArray` returning object dtype for empty :class:`Series` or :class:`DataFrame` (:issue:`XXXXX`) +- Bug in :meth:`Series.str.extract` with :class:`~arrays.StringArray` returning object dtype for empty :class:`Series` or :class:`DataFrame` (:issue:`41441`) Interval ^^^^^^^^ From c507cfa6f8d4ec8f844c753a3d532c107fa512c3 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Thu, 13 May 2021 08:01:37 +0100 Subject: [PATCH 3/3] update --- doc/source/whatsnew/v1.3.0.rst | 2 +- pandas/tests/strings/test_strings.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 6b9b9b75cf048..73a6360c361db 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -750,7 +750,7 @@ Strings - Bug in the conversion from ``pyarrow.ChunkedArray`` to :class:`~arrays.StringArray` when the original had zero chunks (:issue:`41040`) - Bug in :meth:`Series.replace` and :meth:`DataFrame.replace` ignoring replacements with ``regex=True`` for ``StringDType`` data (:issue:`41333`, :issue:`35977`) -- Bug in :meth:`Series.str.extract` with :class:`~arrays.StringArray` returning object dtype for empty :class:`Series` or :class:`DataFrame` (:issue:`41441`) +- Bug in :meth:`Series.str.extract` with :class:`~arrays.StringArray` returning object dtype for empty :class:`DataFrame` (:issue:`41441`) Interval ^^^^^^^^ diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py index 317456fbfcd94..a18d54b4de44d 100644 --- a/pandas/tests/strings/test_strings.py +++ b/pandas/tests/strings/test_strings.py @@ -187,7 +187,7 @@ def test_empty_str_methods(any_string_dtype): DataFrame(columns=[0, 1], dtype=any_string_dtype), empty.str.extract("()()", expand=False), ) - tm.assert_frame_equal(DataFrame(dtype=any_string_dtype), empty.str.get_dummies()) + tm.assert_frame_equal(DataFrame(), empty.str.get_dummies()) tm.assert_series_equal(empty_str, empty_str.str.join("")) tm.assert_series_equal(empty_int, empty.str.len()) tm.assert_series_equal(empty_object, empty_str.str.findall("a"))