Skip to content

Commit a61c556

Browse files
authored
Adjust tests in strings folder for new string option (#56159)
* Adjust tests in strings folder for new string option * BUG: translate losing object dtype with new string dtype * Fix * BUG: Index.str.cat casting result always to object * Update accessor.py * Fix further bugs * Fix * Fix tests * Update accessor.py
1 parent 91ddc8b commit a61c556

8 files changed

+97
-46
lines changed

pandas/core/strings/accessor.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,13 @@ def split(
918918
if is_re(pat):
919919
regex = True
920920
result = self._data.array._str_split(pat, n, expand, regex)
921-
return self._wrap_result(result, returns_string=expand, expand=expand)
921+
if self._data.dtype == "category":
922+
dtype = self._data.dtype.categories.dtype
923+
else:
924+
dtype = object if self._data.dtype == object else None
925+
return self._wrap_result(
926+
result, expand=expand, returns_string=expand, dtype=dtype
927+
)
922928

923929
@Appender(
924930
_shared_docs["str_split"]
@@ -936,7 +942,10 @@ def split(
936942
@forbid_nonstring_types(["bytes"])
937943
def rsplit(self, pat=None, *, n=-1, expand: bool = False):
938944
result = self._data.array._str_rsplit(pat, n=n)
939-
return self._wrap_result(result, expand=expand, returns_string=expand)
945+
dtype = object if self._data.dtype == object else None
946+
return self._wrap_result(
947+
result, expand=expand, returns_string=expand, dtype=dtype
948+
)
940949

941950
_shared_docs[
942951
"str_partition"
@@ -1032,7 +1041,13 @@ def rsplit(self, pat=None, *, n=-1, expand: bool = False):
10321041
@forbid_nonstring_types(["bytes"])
10331042
def partition(self, sep: str = " ", expand: bool = True):
10341043
result = self._data.array._str_partition(sep, expand)
1035-
return self._wrap_result(result, expand=expand, returns_string=expand)
1044+
if self._data.dtype == "category":
1045+
dtype = self._data.dtype.categories.dtype
1046+
else:
1047+
dtype = object if self._data.dtype == object else None
1048+
return self._wrap_result(
1049+
result, expand=expand, returns_string=expand, dtype=dtype
1050+
)
10361051

10371052
@Appender(
10381053
_shared_docs["str_partition"]
@@ -1046,7 +1061,13 @@ def partition(self, sep: str = " ", expand: bool = True):
10461061
@forbid_nonstring_types(["bytes"])
10471062
def rpartition(self, sep: str = " ", expand: bool = True):
10481063
result = self._data.array._str_rpartition(sep, expand)
1049-
return self._wrap_result(result, expand=expand, returns_string=expand)
1064+
if self._data.dtype == "category":
1065+
dtype = self._data.dtype.categories.dtype
1066+
else:
1067+
dtype = object if self._data.dtype == object else None
1068+
return self._wrap_result(
1069+
result, expand=expand, returns_string=expand, dtype=dtype
1070+
)
10501071

10511072
def get(self, i):
10521073
"""
@@ -2752,7 +2773,7 @@ def extract(
27522773
else:
27532774
name = _get_single_group_name(regex)
27542775
result = self._data.array._str_extract(pat, flags=flags, expand=returns_df)
2755-
return self._wrap_result(result, name=name)
2776+
return self._wrap_result(result, name=name, dtype=result_dtype)
27562777

27572778
@forbid_nonstring_types(["bytes"])
27582779
def extractall(self, pat, flags: int = 0) -> DataFrame:
@@ -3492,7 +3513,7 @@ def str_extractall(arr, pat, flags: int = 0) -> DataFrame:
34923513
raise ValueError("pattern contains no capture groups")
34933514

34943515
if isinstance(arr, ABCIndex):
3495-
arr = arr.to_series().reset_index(drop=True)
3516+
arr = arr.to_series().reset_index(drop=True).astype(arr.dtype)
34963517

34973518
columns = _get_group_names(regex)
34983519
match_list = []

pandas/tests/strings/test_api.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
MultiIndex,
99
Series,
1010
_testing as tm,
11+
option_context,
1112
)
1213
from pandas.core.strings.accessor import StringMethods
1314

@@ -163,7 +164,8 @@ def test_api_per_method(
163164

164165
if inferred_dtype in allowed_types:
165166
# xref GH 23555, GH 23556
166-
method(*args, **kwargs) # works!
167+
with option_context("future.no_silent_downcasting", True):
168+
method(*args, **kwargs) # works!
167169
else:
168170
# GH 23011, GH 23163
169171
msg = (

pandas/tests/strings/test_case_justify.py

+24-11
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ def test_title_mixed_object():
2121
s = Series(["FOO", np.nan, "bar", True, datetime.today(), "blah", None, 1, 2.0])
2222
result = s.str.title()
2323
expected = Series(
24-
["Foo", np.nan, "Bar", np.nan, np.nan, "Blah", None, np.nan, np.nan]
24+
["Foo", np.nan, "Bar", np.nan, np.nan, "Blah", None, np.nan, np.nan],
25+
dtype=object,
2526
)
2627
tm.assert_almost_equal(result, expected)
2728

@@ -41,11 +42,15 @@ def test_lower_upper_mixed_object():
4142
s = Series(["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0])
4243

4344
result = s.str.upper()
44-
expected = Series(["A", np.nan, "B", np.nan, np.nan, "FOO", None, np.nan, np.nan])
45+
expected = Series(
46+
["A", np.nan, "B", np.nan, np.nan, "FOO", None, np.nan, np.nan], dtype=object
47+
)
4548
tm.assert_series_equal(result, expected)
4649

4750
result = s.str.lower()
48-
expected = Series(["a", np.nan, "b", np.nan, np.nan, "foo", None, np.nan, np.nan])
51+
expected = Series(
52+
["a", np.nan, "b", np.nan, np.nan, "foo", None, np.nan, np.nan], dtype=object
53+
)
4954
tm.assert_series_equal(result, expected)
5055

5156

@@ -71,7 +76,8 @@ def test_capitalize_mixed_object():
7176
s = Series(["FOO", np.nan, "bar", True, datetime.today(), "blah", None, 1, 2.0])
7277
result = s.str.capitalize()
7378
expected = Series(
74-
["Foo", np.nan, "Bar", np.nan, np.nan, "Blah", None, np.nan, np.nan]
79+
["Foo", np.nan, "Bar", np.nan, np.nan, "Blah", None, np.nan, np.nan],
80+
dtype=object,
7581
)
7682
tm.assert_series_equal(result, expected)
7783

@@ -87,7 +93,8 @@ def test_swapcase_mixed_object():
8793
s = Series(["FOO", np.nan, "bar", True, datetime.today(), "Blah", None, 1, 2.0])
8894
result = s.str.swapcase()
8995
expected = Series(
90-
["foo", np.nan, "BAR", np.nan, np.nan, "bLAH", None, np.nan, np.nan]
96+
["foo", np.nan, "BAR", np.nan, np.nan, "bLAH", None, np.nan, np.nan],
97+
dtype=object,
9198
)
9299
tm.assert_series_equal(result, expected)
93100

@@ -138,19 +145,22 @@ def test_pad_mixed_object():
138145

139146
result = s.str.pad(5, side="left")
140147
expected = Series(
141-
[" a", np.nan, " b", np.nan, np.nan, " ee", None, np.nan, np.nan]
148+
[" a", np.nan, " b", np.nan, np.nan, " ee", None, np.nan, np.nan],
149+
dtype=object,
142150
)
143151
tm.assert_series_equal(result, expected)
144152

145153
result = s.str.pad(5, side="right")
146154
expected = Series(
147-
["a ", np.nan, "b ", np.nan, np.nan, "ee ", None, np.nan, np.nan]
155+
["a ", np.nan, "b ", np.nan, np.nan, "ee ", None, np.nan, np.nan],
156+
dtype=object,
148157
)
149158
tm.assert_series_equal(result, expected)
150159

151160
result = s.str.pad(5, side="both")
152161
expected = Series(
153-
[" a ", np.nan, " b ", np.nan, np.nan, " ee ", None, np.nan, np.nan]
162+
[" a ", np.nan, " b ", np.nan, np.nan, " ee ", None, np.nan, np.nan],
163+
dtype=object,
154164
)
155165
tm.assert_series_equal(result, expected)
156166

@@ -238,7 +248,8 @@ def test_center_ljust_rjust_mixed_object():
238248
None,
239249
np.nan,
240250
np.nan,
241-
]
251+
],
252+
dtype=object,
242253
)
243254
tm.assert_series_equal(result, expected)
244255

@@ -255,7 +266,8 @@ def test_center_ljust_rjust_mixed_object():
255266
None,
256267
np.nan,
257268
np.nan,
258-
]
269+
],
270+
dtype=object,
259271
)
260272
tm.assert_series_equal(result, expected)
261273

@@ -272,7 +284,8 @@ def test_center_ljust_rjust_mixed_object():
272284
None,
273285
np.nan,
274286
np.nan,
275-
]
287+
],
288+
dtype=object,
276289
)
277290
tm.assert_series_equal(result, expected)
278291

pandas/tests/strings/test_extract.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,16 @@ def test_extract_expand_False_mixed_object():
4747
# two groups
4848
result = ser.str.extract(".*(BAD[_]+).*(BAD)", expand=False)
4949
er = [np.nan, np.nan] # empty row
50-
expected = DataFrame([["BAD_", "BAD"], er, ["BAD_", "BAD"], er, er, er, er, er, er])
50+
expected = DataFrame(
51+
[["BAD_", "BAD"], er, ["BAD_", "BAD"], er, er, er, er, er, er], dtype=object
52+
)
5153
tm.assert_frame_equal(result, expected)
5254

5355
# single group
5456
result = ser.str.extract(".*(BAD[_]+).*BAD", expand=False)
5557
expected = Series(
56-
["BAD_", np.nan, "BAD_", np.nan, np.nan, np.nan, None, np.nan, np.nan]
58+
["BAD_", np.nan, "BAD_", np.nan, np.nan, np.nan, None, np.nan, np.nan],
59+
dtype=object,
5760
)
5861
tm.assert_series_equal(result, expected)
5962

@@ -238,7 +241,9 @@ def test_extract_expand_True_mixed_object():
238241
)
239242

240243
result = mixed.str.extract(".*(BAD[_]+).*(BAD)", expand=True)
241-
expected = DataFrame([["BAD_", "BAD"], er, ["BAD_", "BAD"], er, er, er, er, er, er])
244+
expected = DataFrame(
245+
[["BAD_", "BAD"], er, ["BAD_", "BAD"], er, er, er, er, er, er], dtype=object
246+
)
242247
tm.assert_frame_equal(result, expected)
243248

244249

@@ -603,8 +608,8 @@ def test_extractall_stringindex(any_string_dtype):
603608
# index.name doesn't affect to the result
604609
if any_string_dtype == "object":
605610
for idx in [
606-
Index(["a1a2", "b1", "c1"]),
607-
Index(["a1a2", "b1", "c1"], name="xxx"),
611+
Index(["a1a2", "b1", "c1"], dtype=object),
612+
Index(["a1a2", "b1", "c1"], name="xxx", dtype=object),
608613
]:
609614
result = idx.str.extractall(r"[ab](?P<digit>\d)")
610615
tm.assert_frame_equal(result, expected)

pandas/tests/strings/test_find_replace.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def test_contains_nan(any_string_dtype):
242242

243243

244244
@pytest.mark.parametrize("pat", ["foo", ("foo", "baz")])
245-
@pytest.mark.parametrize("dtype", [None, "category"])
245+
@pytest.mark.parametrize("dtype", ["object", "category"])
246246
@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA])
247247
@pytest.mark.parametrize("na", [True, False])
248248
def test_startswith(pat, dtype, null_value, na):
@@ -254,10 +254,10 @@ def test_startswith(pat, dtype, null_value, na):
254254

255255
result = values.str.startswith(pat)
256256
exp = Series([False, np.nan, True, False, False, np.nan, True])
257-
if dtype is None and null_value is pd.NA:
257+
if dtype == "object" and null_value is pd.NA:
258258
# GH#18463
259259
exp = exp.fillna(null_value)
260-
elif dtype is None and null_value is None:
260+
elif dtype == "object" and null_value is None:
261261
exp[exp.isna()] = None
262262
tm.assert_series_equal(result, exp)
263263

@@ -300,7 +300,7 @@ def test_startswith_nullable_string_dtype(nullable_string_dtype, na):
300300

301301

302302
@pytest.mark.parametrize("pat", ["foo", ("foo", "baz")])
303-
@pytest.mark.parametrize("dtype", [None, "category"])
303+
@pytest.mark.parametrize("dtype", ["object", "category"])
304304
@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA])
305305
@pytest.mark.parametrize("na", [True, False])
306306
def test_endswith(pat, dtype, null_value, na):
@@ -312,10 +312,10 @@ def test_endswith(pat, dtype, null_value, na):
312312

313313
result = values.str.endswith(pat)
314314
exp = Series([False, np.nan, False, False, True, np.nan, True])
315-
if dtype is None and null_value is pd.NA:
315+
if dtype == "object" and null_value is pd.NA:
316316
# GH#18463
317-
exp = exp.fillna(pd.NA)
318-
elif dtype is None and null_value is None:
317+
exp = exp.fillna(null_value)
318+
elif dtype == "object" and null_value is None:
319319
exp[exp.isna()] = None
320320
tm.assert_series_equal(result, exp)
321321

@@ -382,7 +382,9 @@ def test_replace_mixed_object():
382382
["aBAD", np.nan, "bBAD", True, datetime.today(), "fooBAD", None, 1, 2.0]
383383
)
384384
result = Series(ser).str.replace("BAD[_]*", "", regex=True)
385-
expected = Series(["a", np.nan, "b", np.nan, np.nan, "foo", None, np.nan, np.nan])
385+
expected = Series(
386+
["a", np.nan, "b", np.nan, np.nan, "foo", None, np.nan, np.nan], dtype=object
387+
)
386388
tm.assert_series_equal(result, expected)
387389

388390

@@ -469,7 +471,9 @@ def test_replace_compiled_regex_mixed_object():
469471
["aBAD", np.nan, "bBAD", True, datetime.today(), "fooBAD", None, 1, 2.0]
470472
)
471473
result = Series(ser).str.replace(pat, "", regex=True)
472-
expected = Series(["a", np.nan, "b", np.nan, np.nan, "foo", None, np.nan, np.nan])
474+
expected = Series(
475+
["a", np.nan, "b", np.nan, np.nan, "foo", None, np.nan, np.nan], dtype=object
476+
)
473477
tm.assert_series_equal(result, expected)
474478

475479

@@ -913,7 +917,7 @@ def test_translate_mixed_object():
913917
# Series with non-string values
914918
s = Series(["a", "b", "c", 1.2])
915919
table = str.maketrans("abc", "cde")
916-
expected = Series(["c", "d", "e", np.nan])
920+
expected = Series(["c", "d", "e", np.nan], dtype=object)
917921
result = s.str.translate(table)
918922
tm.assert_series_equal(result, expected)
919923

pandas/tests/strings/test_split_partition.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -681,22 +681,24 @@ def test_partition_sep_kwarg(any_string_dtype, method):
681681
def test_get():
682682
ser = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"])
683683
result = ser.str.split("_").str.get(1)
684-
expected = Series(["b", "d", np.nan, "g"])
684+
expected = Series(["b", "d", np.nan, "g"], dtype=object)
685685
tm.assert_series_equal(result, expected)
686686

687687

688688
def test_get_mixed_object():
689689
ser = Series(["a_b_c", np.nan, "c_d_e", True, datetime.today(), None, 1, 2.0])
690690
result = ser.str.split("_").str.get(1)
691-
expected = Series(["b", np.nan, "d", np.nan, np.nan, None, np.nan, np.nan])
691+
expected = Series(
692+
["b", np.nan, "d", np.nan, np.nan, None, np.nan, np.nan], dtype=object
693+
)
692694
tm.assert_series_equal(result, expected)
693695

694696

695697
@pytest.mark.parametrize("idx", [2, -3])
696698
def test_get_bounds(idx):
697699
ser = Series(["1_2_3_4_5", "6_7_8_9_10", "11_12"])
698700
result = ser.str.split("_").str.get(idx)
699-
expected = Series(["3", "8", np.nan])
701+
expected = Series(["3", "8", np.nan], dtype=object)
700702
tm.assert_series_equal(result, expected)
701703

702704

pandas/tests/strings/test_string_array.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
DataFrame,
99
Series,
1010
_testing as tm,
11+
option_context,
1112
)
1213

1314

@@ -56,7 +57,8 @@ def test_string_array(nullable_string_dtype, any_string_method):
5657
columns = expected.select_dtypes(include="object").columns
5758
assert all(result[columns].dtypes == nullable_string_dtype)
5859
result[columns] = result[columns].astype(object)
59-
expected[columns] = expected[columns].fillna(NA) # GH#18463
60+
with option_context("future.no_silent_downcasting", True):
61+
expected[columns] = expected[columns].fillna(NA) # GH#18463
6062

6163
tm.assert_equal(result, expected)
6264

0 commit comments

Comments
 (0)