Skip to content

Commit 2e3d624

Browse files
BUG: Series.str.extract with StringArray returning object dtype (#41441)
1 parent cf6578a commit 2e3d624

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

doc/source/whatsnew/v1.3.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -751,6 +751,7 @@ Strings
751751

752752
- Bug in the conversion from ``pyarrow.ChunkedArray`` to :class:`~arrays.StringArray` when the original had zero chunks (:issue:`41040`)
753753
- Bug in :meth:`Series.replace` and :meth:`DataFrame.replace` ignoring replacements with ``regex=True`` for ``StringDtype`` data (:issue:`41333`, :issue:`35977`)
754+
- Bug in :meth:`Series.str.extract` with :class:`~arrays.StringArray` returning object dtype for empty :class:`DataFrame` (:issue:`41441`)
754755

755756
Interval
756757
^^^^^^^^

pandas/core/strings/accessor.py

+5 -6
Original file line number | Diff line number | Diff line change
@@ -3113,17 +3113,16 @@ def _str_extract_noexpand(arr, pat, flags=0):
31133113
# error: Incompatible types in assignment (expression has type
31143114
# "DataFrame", variable has type "ndarray")
31153115
result = DataFrame( # type: ignore[assignment]
3116-
columns=columns, dtype=object
3116+
columns=columns, dtype=result_dtype
31173117
)
31183118
else:
3119-
dtype = _result_dtype(arr)
31203119
# error: Incompatible types in assignment (expression has type
31213120
# "DataFrame", variable has type "ndarray")
31223121
result = DataFrame( # type:ignore[assignment]
31233122
[groups_or_na(val) for val in arr],
31243123
columns=columns,
31253124
index=arr.index,
3126-
dtype=dtype,
3125+
dtype=result_dtype,
31273126
)
31283127
return result, name
31293128

@@ -3140,19 +3139,19 @@ def _str_extract_frame(arr, pat, flags=0):
31403139
regex = re.compile(pat, flags=flags)
31413140
groups_or_na = _groups_or_na_fun(regex)
31423141
columns = _get_group_names(regex)
3142+
result_dtype = _result_dtype(arr)
31433143

31443144
if len(arr) == 0:
3145-
return DataFrame(columns=columns, dtype=object)
3145+
return DataFrame(columns=columns, dtype=result_dtype)
31463146
try:
31473147
result_index = arr.index
31483148
except AttributeError:
31493149
result_index = None
3150-
dtype = _result_dtype(arr)
31513150
return DataFrame(
31523151
[groups_or_na(val) for val in arr],
31533152
columns=columns,
31543153
index=result_index,
3155-
dtype=dtype,
3154+
dtype=result_dtype,
31563155
)
31573156

31583157

pandas/tests/strings/test_strings.py

+6 -4
Original file line number | Diff line number | Diff line change
@@ -175,17 +175,19 @@ def test_empty_str_methods(any_string_dtype):
175175
tm.assert_series_equal(empty_str, empty.str.repeat(3))
176176
tm.assert_series_equal(empty_bool, empty.str.match("^a"))
177177
tm.assert_frame_equal(
178-
DataFrame(columns=[0], dtype=str), empty.str.extract("()", expand=True)
178+
DataFrame(columns=[0], dtype=any_string_dtype),
179+
empty.str.extract("()", expand=True),
179180
)
180181
tm.assert_frame_equal(
181-
DataFrame(columns=[0, 1], dtype=str), empty.str.extract("()()", expand=True)
182+
DataFrame(columns=[0, 1], dtype=any_string_dtype),
183+
empty.str.extract("()()", expand=True),
182184
)
183185
tm.assert_series_equal(empty_str, empty.str.extract("()", expand=False))
184186
tm.assert_frame_equal(
185-
DataFrame(columns=[0, 1], dtype=str),
187+
DataFrame(columns=[0, 1], dtype=any_string_dtype),
186188
empty.str.extract("()()", expand=False),
187189
)
188-
tm.assert_frame_equal(DataFrame(dtype=str), empty.str.get_dummies())
190+
tm.assert_frame_equal(DataFrame(), empty.str.get_dummies())
189191
tm.assert_series_equal(empty_str, empty_str.str.join(""))
190192
tm.assert_series_equal(empty_int, empty.str.len())
191193
tm.assert_series_equal(empty_object, empty_str.str.findall("a"))

0 commit comments

Comments
 (0)