diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 84f9dae8a0850..73a6360c361db 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -750,6 +750,7 @@ Strings - Bug in the conversion from ``pyarrow.ChunkedArray`` to :class:`~arrays.StringArray` when the original had zero chunks (:issue:`41040`) - Bug in :meth:`Series.replace` and :meth:`DataFrame.replace` ignoring replacements with ``regex=True`` for ``StringDType`` data (:issue:`41333`, :issue:`35977`) +- Bug in :meth:`Series.str.extract` with :class:`~arrays.StringArray` returning object dtype for empty :class:`DataFrame` (:issue:`41441`) Interval ^^^^^^^^ diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 2646ddfa45b58..025ec232adcb5 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -3108,17 +3108,16 @@ def _str_extract_noexpand(arr, pat, flags=0): # error: Incompatible types in assignment (expression has type # "DataFrame", variable has type "ndarray") result = DataFrame( # type: ignore[assignment] - columns=columns, dtype=object + columns=columns, dtype=result_dtype ) else: - dtype = _result_dtype(arr) # error: Incompatible types in assignment (expression has type # "DataFrame", variable has type "ndarray") result = DataFrame( # type:ignore[assignment] [groups_or_na(val) for val in arr], columns=columns, index=arr.index, - dtype=dtype, + dtype=result_dtype, ) return result, name @@ -3135,19 +3134,19 @@ def _str_extract_frame(arr, pat, flags=0): regex = re.compile(pat, flags=flags) groups_or_na = _groups_or_na_fun(regex) columns = _get_group_names(regex) + result_dtype = _result_dtype(arr) if len(arr) == 0: - return DataFrame(columns=columns, dtype=object) + return DataFrame(columns=columns, dtype=result_dtype) try: result_index = arr.index except AttributeError: result_index = None - dtype = _result_dtype(arr) return DataFrame( [groups_or_na(val) for val in arr], columns=columns, index=result_index, - dtype=dtype, + dtype=result_dtype, ) diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py index 5d8a63fe481f8..a18d54b4de44d 100644 --- a/pandas/tests/strings/test_strings.py +++ b/pandas/tests/strings/test_strings.py @@ -175,17 +175,19 @@ def test_empty_str_methods(any_string_dtype): tm.assert_series_equal(empty_str, empty.str.repeat(3)) tm.assert_series_equal(empty_bool, empty.str.match("^a")) tm.assert_frame_equal( - DataFrame(columns=[0], dtype=str), empty.str.extract("()", expand=True) + DataFrame(columns=[0], dtype=any_string_dtype), + empty.str.extract("()", expand=True), ) tm.assert_frame_equal( - DataFrame(columns=[0, 1], dtype=str), empty.str.extract("()()", expand=True) + DataFrame(columns=[0, 1], dtype=any_string_dtype), + empty.str.extract("()()", expand=True), ) tm.assert_series_equal(empty_str, empty.str.extract("()", expand=False)) tm.assert_frame_equal( - DataFrame(columns=[0, 1], dtype=str), + DataFrame(columns=[0, 1], dtype=any_string_dtype), empty.str.extract("()()", expand=False), ) - tm.assert_frame_equal(DataFrame(dtype=str), empty.str.get_dummies()) + tm.assert_frame_equal(DataFrame(), empty.str.get_dummies()) tm.assert_series_equal(empty_str, empty_str.str.join("")) tm.assert_series_equal(empty_int, empty.str.len()) tm.assert_series_equal(empty_object, empty_str.str.findall("a"))