Skip to content

Commit 59df6a8

Browse files
[ArrowStringArray] TST: parametrize str.split tests (#41392)
1 parent 7f07928 commit 59df6a8

File tree

2 files changed

+133
-83
lines changed

2 files changed

+133
-83
lines changed

asv_bench/benchmarks/strings.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -230,16 +230,21 @@ def time_contains(self, dtype, regex):
230230

231231
class Split:
232232

233-
params = [True, False]
234-
param_names = ["expand"]
233+
params = (["str", "string", "arrow_string"], [True, False])
234+
param_names = ["dtype", "expand"]
235+
236+
def setup(self, dtype, expand):
237+
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
235238

236-
def setup(self, expand):
237-
self.s = Series(tm.makeStringIndex(10 ** 5)).str.join("--")
239+
try:
240+
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype).str.join("--")
241+
except ImportError:
242+
raise NotImplementedError
238243

239-
def time_split(self, expand):
244+
def time_split(self, dtype, expand):
240245
self.s.str.split("--", expand=expand)
241246

242-
def time_rsplit(self, expand):
247+
def time_rsplit(self, dtype, expand):
243248
self.s.str.rsplit("--", expand=expand)
244249

245250

pandas/tests/strings/test_split_partition.py

+122-77
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,29 @@
1313
)
1414

1515

16-
def test_split():
17-
values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"])
16+
def test_split(any_string_dtype):
17+
values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"], dtype=any_string_dtype)
1818

1919
result = values.str.split("_")
2020
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
2121
tm.assert_series_equal(result, exp)
2222

2323
# more than one char
24-
values = Series(["a__b__c", "c__d__e", np.nan, "f__g__h"])
24+
values = Series(["a__b__c", "c__d__e", np.nan, "f__g__h"], dtype=any_string_dtype)
2525
result = values.str.split("__")
2626
tm.assert_series_equal(result, exp)
2727

2828
result = values.str.split("__", expand=False)
2929
tm.assert_series_equal(result, exp)
3030

31-
# mixed
31+
# regex split
32+
values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype)
33+
result = values.str.split("[,_]")
34+
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
35+
tm.assert_series_equal(result, exp)
36+
37+
38+
def test_split_object_mixed():
3239
mixed = Series(["a_b_c", np.nan, "d_e_f", True, datetime.today(), None, 1, 2.0])
3340
result = mixed.str.split("_")
3441
exp = Series(
@@ -50,17 +57,10 @@ def test_split():
5057
assert isinstance(result, Series)
5158
tm.assert_almost_equal(result, exp)
5259

53-
# regex split
54-
values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"])
55-
result = values.str.split("[,_]")
56-
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
57-
tm.assert_series_equal(result, exp)
58-
5960

60-
@pytest.mark.parametrize("dtype", [object, "string"])
6161
@pytest.mark.parametrize("method", ["split", "rsplit"])
62-
def test_split_n(dtype, method):
63-
s = Series(["a b", pd.NA, "b c"], dtype=dtype)
62+
def test_split_n(any_string_dtype, method):
63+
s = Series(["a b", pd.NA, "b c"], dtype=any_string_dtype)
6464
expected = Series([["a", "b"], pd.NA, ["b", "c"]])
6565

6666
result = getattr(s.str, method)(" ", n=None)
@@ -70,20 +70,34 @@ def test_split_n(dtype, method):
7070
tm.assert_series_equal(result, expected)
7171

7272

73-
def test_rsplit():
74-
values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"])
73+
def test_rsplit(any_string_dtype):
74+
values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"], dtype=any_string_dtype)
7575
result = values.str.rsplit("_")
7676
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
7777
tm.assert_series_equal(result, exp)
7878

7979
# more than one char
80-
values = Series(["a__b__c", "c__d__e", np.nan, "f__g__h"])
80+
values = Series(["a__b__c", "c__d__e", np.nan, "f__g__h"], dtype=any_string_dtype)
8181
result = values.str.rsplit("__")
8282
tm.assert_series_equal(result, exp)
8383

8484
result = values.str.rsplit("__", expand=False)
8585
tm.assert_series_equal(result, exp)
8686

87+
# regex split is not supported by rsplit
88+
values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype)
89+
result = values.str.rsplit("[,_]")
90+
exp = Series([["a,b_c"], ["c_d,e"], np.nan, ["f,g,h"]])
91+
tm.assert_series_equal(result, exp)
92+
93+
# setting max number of splits, make sure it's from reverse
94+
values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"], dtype=any_string_dtype)
95+
result = values.str.rsplit("_", n=1)
96+
exp = Series([["a_b", "c"], ["c_d", "e"], np.nan, ["f_g", "h"]])
97+
tm.assert_series_equal(result, exp)
98+
99+
100+
def test_rsplit_object_mixed():
87101
# mixed
88102
mixed = Series(["a_b_c", np.nan, "d_e_f", True, datetime.today(), None, 1, 2.0])
89103
result = mixed.str.rsplit("_")
@@ -106,87 +120,96 @@ def test_rsplit():
106120
assert isinstance(result, Series)
107121
tm.assert_almost_equal(result, exp)
108122

109-
# regex split is not supported by rsplit
110-
values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"])
111-
result = values.str.rsplit("[,_]")
112-
exp = Series([["a,b_c"], ["c_d,e"], np.nan, ["f,g,h"]])
113-
tm.assert_series_equal(result, exp)
114123

115-
# setting max number of splits, make sure it's from reverse
116-
values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"])
117-
result = values.str.rsplit("_", n=1)
118-
exp = Series([["a_b", "c"], ["c_d", "e"], np.nan, ["f_g", "h"]])
119-
tm.assert_series_equal(result, exp)
120-
121-
122-
def test_split_blank_string():
124+
def test_split_blank_string(any_string_dtype):
123125
# expand blank split GH 20067
124-
values = Series([""], name="test")
126+
values = Series([""], name="test", dtype=any_string_dtype)
125127
result = values.str.split(expand=True)
126-
exp = DataFrame([[]]) # NOTE: this is NOT an empty DataFrame
128+
exp = DataFrame([[]], dtype=any_string_dtype) # NOTE: this is NOT an empty df
127129
tm.assert_frame_equal(result, exp)
128130

129-
values = Series(["a b c", "a b", "", " "], name="test")
131+
values = Series(["a b c", "a b", "", " "], name="test", dtype=any_string_dtype)
130132
result = values.str.split(expand=True)
131133
exp = DataFrame(
132134
[
133135
["a", "b", "c"],
134136
["a", "b", np.nan],
135137
[np.nan, np.nan, np.nan],
136138
[np.nan, np.nan, np.nan],
137-
]
139+
],
140+
dtype=any_string_dtype,
138141
)
139142
tm.assert_frame_equal(result, exp)
140143

141144

142-
def test_split_noargs():
145+
def test_split_noargs(any_string_dtype):
143146
# #1859
144-
s = Series(["Wes McKinney", "Travis Oliphant"])
147+
s = Series(["Wes McKinney", "Travis Oliphant"], dtype=any_string_dtype)
145148
result = s.str.split()
146149
expected = ["Travis", "Oliphant"]
147150
assert result[1] == expected
148151
result = s.str.rsplit()
149152
assert result[1] == expected
150153

151154

152-
def test_split_maxsplit():
155+
@pytest.mark.parametrize(
156+
"data, pat",
157+
[
158+
(["bd asdf jfg", "kjasdflqw asdfnfk"], None),
159+
(["bd asdf jfg", "kjasdflqw asdfnfk"], "asdf"),
160+
(["bd_asdf_jfg", "kjasdflqw_asdfnfk"], "_"),
161+
],
162+
)
163+
def test_split_maxsplit(data, pat, any_string_dtype):
153164
# re.split 0, str.split -1
154-
s = Series(["bd asdf jfg", "kjasdflqw asdfnfk"])
155-
156-
result = s.str.split(n=-1)
157-
xp = s.str.split()
158-
tm.assert_series_equal(result, xp)
165+
s = Series(data, dtype=any_string_dtype)
159166

160-
result = s.str.split(n=0)
167+
result = s.str.split(pat=pat, n=-1)
168+
xp = s.str.split(pat=pat)
161169
tm.assert_series_equal(result, xp)
162170

163-
xp = s.str.split("asdf")
164-
result = s.str.split("asdf", n=0)
171+
result = s.str.split(pat=pat, n=0)
165172
tm.assert_series_equal(result, xp)
166173

167-
result = s.str.split("asdf", n=-1)
168-
tm.assert_series_equal(result, xp)
169174

170-
171-
def test_split_no_pat_with_nonzero_n():
172-
s = Series(["split once", "split once too!"])
173-
result = s.str.split(n=1)
174-
expected = Series({0: ["split", "once"], 1: ["split", "once too!"]})
175+
@pytest.mark.parametrize(
176+
"data, pat, expected",
177+
[
178+
(
179+
["split once", "split once too!"],
180+
None,
181+
Series({0: ["split", "once"], 1: ["split", "once too!"]}),
182+
),
183+
(
184+
["split_once", "split_once_too!"],
185+
"_",
186+
Series({0: ["split", "once"], 1: ["split", "once_too!"]}),
187+
),
188+
],
189+
)
190+
def test_split_no_pat_with_nonzero_n(data, pat, expected, any_string_dtype):
191+
s = Series(data, dtype=any_string_dtype)
192+
result = s.str.split(pat=pat, n=1)
175193
tm.assert_series_equal(expected, result, check_index_type=False)
176194

177195

178-
def test_split_to_dataframe():
179-
s = Series(["nosplit", "alsonosplit"])
196+
def test_split_to_dataframe(any_string_dtype):
197+
s = Series(["nosplit", "alsonosplit"], dtype=any_string_dtype)
180198
result = s.str.split("_", expand=True)
181-
exp = DataFrame({0: Series(["nosplit", "alsonosplit"])})
199+
exp = DataFrame({0: Series(["nosplit", "alsonosplit"], dtype=any_string_dtype)})
182200
tm.assert_frame_equal(result, exp)
183201

184-
s = Series(["some_equal_splits", "with_no_nans"])
202+
s = Series(["some_equal_splits", "with_no_nans"], dtype=any_string_dtype)
185203
result = s.str.split("_", expand=True)
186-
exp = DataFrame({0: ["some", "with"], 1: ["equal", "no"], 2: ["splits", "nans"]})
204+
exp = DataFrame(
205+
{0: ["some", "with"], 1: ["equal", "no"], 2: ["splits", "nans"]},
206+
dtype=any_string_dtype,
207+
)
187208
tm.assert_frame_equal(result, exp)
188209

189-
s = Series(["some_unequal_splits", "one_of_these_things_is_not"])
210+
s = Series(
211+
["some_unequal_splits", "one_of_these_things_is_not"], dtype=any_string_dtype
212+
)
190213
result = s.str.split("_", expand=True)
191214
exp = DataFrame(
192215
{
@@ -196,14 +219,19 @@ def test_split_to_dataframe():
196219
3: [np.nan, "things"],
197220
4: [np.nan, "is"],
198221
5: [np.nan, "not"],
199-
}
222+
},
223+
dtype=any_string_dtype,
200224
)
201225
tm.assert_frame_equal(result, exp)
202226

203-
s = Series(["some_splits", "with_index"], index=["preserve", "me"])
227+
s = Series(
228+
["some_splits", "with_index"], index=["preserve", "me"], dtype=any_string_dtype
229+
)
204230
result = s.str.split("_", expand=True)
205231
exp = DataFrame(
206-
{0: ["some", "with"], 1: ["splits", "index"]}, index=["preserve", "me"]
232+
{0: ["some", "with"], 1: ["splits", "index"]},
233+
index=["preserve", "me"],
234+
dtype=any_string_dtype,
207235
)
208236
tm.assert_frame_equal(result, exp)
209237

@@ -250,29 +278,41 @@ def test_split_to_multiindex_expand():
250278
idx.str.split("_", expand="not_a_boolean")
251279

252280

253-
def test_rsplit_to_dataframe_expand():
254-
s = Series(["nosplit", "alsonosplit"])
281+
def test_rsplit_to_dataframe_expand(any_string_dtype):
282+
s = Series(["nosplit", "alsonosplit"], dtype=any_string_dtype)
255283
result = s.str.rsplit("_", expand=True)
256-
exp = DataFrame({0: Series(["nosplit", "alsonosplit"])})
284+
exp = DataFrame({0: Series(["nosplit", "alsonosplit"])}, dtype=any_string_dtype)
257285
tm.assert_frame_equal(result, exp)
258286

259-
s = Series(["some_equal_splits", "with_no_nans"])
287+
s = Series(["some_equal_splits", "with_no_nans"], dtype=any_string_dtype)
260288
result = s.str.rsplit("_", expand=True)
261-
exp = DataFrame({0: ["some", "with"], 1: ["equal", "no"], 2: ["splits", "nans"]})
289+
exp = DataFrame(
290+
{0: ["some", "with"], 1: ["equal", "no"], 2: ["splits", "nans"]},
291+
dtype=any_string_dtype,
292+
)
262293
tm.assert_frame_equal(result, exp)
263294

264295
result = s.str.rsplit("_", expand=True, n=2)
265-
exp = DataFrame({0: ["some", "with"], 1: ["equal", "no"], 2: ["splits", "nans"]})
296+
exp = DataFrame(
297+
{0: ["some", "with"], 1: ["equal", "no"], 2: ["splits", "nans"]},
298+
dtype=any_string_dtype,
299+
)
266300
tm.assert_frame_equal(result, exp)
267301

268302
result = s.str.rsplit("_", expand=True, n=1)
269-
exp = DataFrame({0: ["some_equal", "with_no"], 1: ["splits", "nans"]})
303+
exp = DataFrame(
304+
{0: ["some_equal", "with_no"], 1: ["splits", "nans"]}, dtype=any_string_dtype
305+
)
270306
tm.assert_frame_equal(result, exp)
271307

272-
s = Series(["some_splits", "with_index"], index=["preserve", "me"])
308+
s = Series(
309+
["some_splits", "with_index"], index=["preserve", "me"], dtype=any_string_dtype
310+
)
273311
result = s.str.rsplit("_", expand=True)
274312
exp = DataFrame(
275-
{0: ["some", "with"], 1: ["splits", "index"]}, index=["preserve", "me"]
313+
{0: ["some", "with"], 1: ["splits", "index"]},
314+
index=["preserve", "me"],
315+
dtype=any_string_dtype,
276316
)
277317
tm.assert_frame_equal(result, exp)
278318

@@ -297,30 +337,35 @@ def test_rsplit_to_multiindex_expand():
297337
assert result.nlevels == 2
298338

299339

300-
def test_split_nan_expand():
340+
def test_split_nan_expand(any_string_dtype):
301341
# gh-18450
302-
s = Series(["foo,bar,baz", np.nan])
342+
s = Series(["foo,bar,baz", np.nan], dtype=any_string_dtype)
303343
result = s.str.split(",", expand=True)
304-
exp = DataFrame([["foo", "bar", "baz"], [np.nan, np.nan, np.nan]])
344+
exp = DataFrame(
345+
[["foo", "bar", "baz"], [np.nan, np.nan, np.nan]], dtype=any_string_dtype
346+
)
305347
tm.assert_frame_equal(result, exp)
306348

307-
# check that these are actually np.nan and not None
349+
# check that these are actually np.nan/pd.NA and not None
308350
# TODO see GH 18463
309351
# tm.assert_frame_equal does not differentiate
310-
assert all(np.isnan(x) for x in result.iloc[1])
352+
if any_string_dtype == "object":
353+
assert all(np.isnan(x) for x in result.iloc[1])
354+
else:
355+
assert all(x is pd.NA for x in result.iloc[1])
311356

312357

313-
def test_split_with_name():
358+
def test_split_with_name(any_string_dtype):
314359
# GH 12617
315360

316361
# should preserve name
317-
s = Series(["a,b", "c,d"], name="xxx")
362+
s = Series(["a,b", "c,d"], name="xxx", dtype=any_string_dtype)
318363
res = s.str.split(",")
319364
exp = Series([["a", "b"], ["c", "d"]], name="xxx")
320365
tm.assert_series_equal(res, exp)
321366

322367
res = s.str.split(",", expand=True)
323-
exp = DataFrame([["a", "b"], ["c", "d"]])
368+
exp = DataFrame([["a", "b"], ["c", "d"]], dtype=any_string_dtype)
324369
tm.assert_frame_equal(res, exp)
325370

326371
idx = Index(["a,b", "c,d"], name="xxx")

0 commit comments

Comments
 (0)