Skip to content

Commit 4d875d1

Browse files
TST (string dtype): change any_string_dtype fixture to use actual dtype instances
1 parent 9c8c685 commit 4d875d1

File tree

5 files changed: +102 -37 lines changed

pandas/conftest.py

+16 -7
Original file line number | Diff line number | Diff line change
@@ -1354,18 +1354,27 @@ def object_dtype(request):
13541354

13551355
@pytest.fixture(
13561356
params=[
1357-
"object",
1358-
"string[python]",
1359-
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
1360-
pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")),
1361-
]
1357+
np.dtype("object"),
1358+
pd.StringDtype("python"),
1359+
pytest.param(pd.StringDtype("pyarrow"), marks=td.skip_if_no("pyarrow")),
1360+
pytest.param(
1361+
pd.StringDtype("pyarrow", na_value=np.nan), marks=td.skip_if_no("pyarrow")
1362+
),
1363+
],
1364+
ids=[
1365+
"string=object",
1366+
"string=string[python]",
1367+
"string=string[pyarrow]",
1368+
"string=str[pyarrow]",
1369+
],
13621370
)
13631371
def any_string_dtype(request):
13641372
"""
13651373
Parametrized fixture for string dtypes.
13661374
* 'object'
1367-
* 'string[python]'
1368-
* 'string[pyarrow]'
1375+
* 'string[python]' (NA variant)
1376+
* 'string[pyarrow]' (NA variant)
1377+
* 'str' (NaN variant, with pyarrow)
13691378
"""
13701379
return request.param
13711380

pandas/tests/strings/__init__.py

+9 -1
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,15 @@
22

33
import pandas as pd
44

5-
object_pyarrow_numpy = ("object", "string[pyarrow_numpy]")
5+
6+
def is_object_or_nan_string_dtype(dtype):
7+
"""
8+
Check if string-like dtype is following NaN semantics, i.e. is object
9+
dtype or a NaN-variant of the StringDtype.
10+
"""
11+
return (isinstance(dtype, np.dtype) and dtype == "object") or (
12+
dtype.na_value is np.nan
13+
)
614

715

816
def _convert_na_value(ser, expected):

pandas/tests/strings/test_find_replace.py

+52 -18
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
)
1414
from pandas.tests.strings import (
1515
_convert_na_value,
16-
object_pyarrow_numpy,
16+
is_object_or_nan_string_dtype,
1717
)
1818

1919
# --------------------------------------------------------------------------------------
@@ -33,7 +33,9 @@ def test_contains(any_string_dtype):
3333
pat = "mmm[_]+"
3434

3535
result = values.str.contains(pat)
36-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
36+
expected_dtype = (
37+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
38+
)
3739
expected = Series(
3840
np.array([False, np.nan, True, True, False], dtype=np.object_),
3941
dtype=expected_dtype,
@@ -52,7 +54,9 @@ def test_contains(any_string_dtype):
5254
dtype=any_string_dtype,
5355
)
5456
result = values.str.contains(pat)
55-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
57+
expected_dtype = (
58+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
59+
)
5660
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
5761
tm.assert_series_equal(result, expected)
5862

@@ -79,14 +83,18 @@ def test_contains(any_string_dtype):
7983
pat = "mmm[_]+"
8084

8185
result = values.str.contains(pat)
82-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
86+
expected_dtype = (
87+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
88+
)
8389
expected = Series(
8490
np.array([False, np.nan, True, True], dtype=np.object_), dtype=expected_dtype
8591
)
8692
tm.assert_series_equal(result, expected)
8793

8894
result = values.str.contains(pat, na=False)
89-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
95+
expected_dtype = (
96+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
97+
)
9098
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
9199
tm.assert_series_equal(result, expected)
92100

@@ -171,7 +179,9 @@ def test_contains_moar(any_string_dtype):
171179
)
172180

173181
result = s.str.contains("a")
174-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
182+
expected_dtype = (
183+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
184+
)
175185
expected = Series(
176186
[False, False, False, True, True, False, np.nan, False, False, True],
177187
dtype=expected_dtype,
@@ -212,7 +222,9 @@ def test_contains_nan(any_string_dtype):
212222
s = Series([np.nan, np.nan, np.nan], dtype=any_string_dtype)
213223

214224
result = s.str.contains("foo", na=False)
215-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
225+
expected_dtype = (
226+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
227+
)
216228
expected = Series([False, False, False], dtype=expected_dtype)
217229
tm.assert_series_equal(result, expected)
218230

@@ -230,7 +242,9 @@ def test_contains_nan(any_string_dtype):
230242
tm.assert_series_equal(result, expected)
231243

232244
result = s.str.contains("foo")
233-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
245+
expected_dtype = (
246+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
247+
)
234248
expected = Series([np.nan, np.nan, np.nan], dtype=expected_dtype)
235249
tm.assert_series_equal(result, expected)
236250

@@ -675,7 +689,9 @@ def test_replace_regex_single_character(regex, any_string_dtype):
675689

676690
def test_match(any_string_dtype):
677691
# New match behavior introduced in 0.13
678-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
692+
expected_dtype = (
693+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
694+
)
679695

680696
values = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype)
681697
result = values.str.match(".*(BAD[_]+).*(BAD)")
@@ -730,20 +746,26 @@ def test_match_na_kwarg(any_string_dtype):
730746
s = Series(["a", "b", np.nan], dtype=any_string_dtype)
731747

732748
result = s.str.match("a", na=False)
733-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
749+
expected_dtype = (
750+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
751+
)
734752
expected = Series([True, False, False], dtype=expected_dtype)
735753
tm.assert_series_equal(result, expected)
736754

737755
result = s.str.match("a")
738-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
756+
expected_dtype = (
757+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
758+
)
739759
expected = Series([True, False, np.nan], dtype=expected_dtype)
740760
tm.assert_series_equal(result, expected)
741761

742762

743763
def test_match_case_kwarg(any_string_dtype):
744764
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
745765
result = values.str.match("ab", case=False)
746-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
766+
expected_dtype = (
767+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
768+
)
747769
expected = Series([True, True, True, True], dtype=expected_dtype)
748770
tm.assert_series_equal(result, expected)
749771

@@ -759,7 +781,9 @@ def test_fullmatch(any_string_dtype):
759781
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
760782
)
761783
result = ser.str.fullmatch(".*BAD[_]+.*BAD")
762-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
784+
expected_dtype = (
785+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
786+
)
763787
expected = Series([True, False, np.nan, False], dtype=expected_dtype)
764788
tm.assert_series_equal(result, expected)
765789

@@ -768,7 +792,9 @@ def test_fullmatch_dollar_literal(any_string_dtype):
768792
# GH 56652
769793
ser = Series(["foo", "foo$foo", np.nan, "foo$"], dtype=any_string_dtype)
770794
result = ser.str.fullmatch("foo\\$")
771-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
795+
expected_dtype = (
796+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
797+
)
772798
expected = Series([False, False, np.nan, True], dtype=expected_dtype)
773799
tm.assert_series_equal(result, expected)
774800

@@ -778,14 +804,18 @@ def test_fullmatch_na_kwarg(any_string_dtype):
778804
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
779805
)
780806
result = ser.str.fullmatch(".*BAD[_]+.*BAD", na=False)
781-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
807+
expected_dtype = (
808+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
809+
)
782810
expected = Series([True, False, False, False], dtype=expected_dtype)
783811
tm.assert_series_equal(result, expected)
784812

785813

786814
def test_fullmatch_case_kwarg(any_string_dtype, performance_warning):
787815
ser = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
788-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
816+
expected_dtype = (
817+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
818+
)
789819

790820
expected = Series([True, False, False, False], dtype=expected_dtype)
791821

@@ -859,7 +889,9 @@ def test_find(any_string_dtype):
859889
ser = Series(
860890
["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXXX"], dtype=any_string_dtype
861891
)
862-
expected_dtype = np.int64 if any_string_dtype in object_pyarrow_numpy else "Int64"
892+
expected_dtype = (
893+
np.int64 if is_object_or_nan_string_dtype(any_string_dtype) else "Int64"
894+
)
863895

864896
result = ser.str.find("EF")
865897
expected = Series([4, 3, 1, 0, -1], dtype=expected_dtype)
@@ -911,7 +943,9 @@ def test_find_nan(any_string_dtype):
911943
ser = Series(
912944
["ABCDEFG", np.nan, "DEFGHIJEF", np.nan, "XXXX"], dtype=any_string_dtype
913945
)
914-
expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64"
946+
expected_dtype = (
947+
np.float64 if is_object_or_nan_string_dtype(any_string_dtype) else "Int64"
948+
)
915949

916950
result = ser.str.find("EF")
917951
expected = Series([4, np.nan, 1, np.nan, -1], dtype=expected_dtype)

pandas/tests/strings/test_split_partition.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
)
1515
from pandas.tests.strings import (
1616
_convert_na_value,
17-
object_pyarrow_numpy,
17+
is_object_or_nan_string_dtype,
1818
)
1919

2020

@@ -385,7 +385,7 @@ def test_split_nan_expand(any_string_dtype):
385385
# check that these are actually np.nan/pd.NA and not None
386386
# TODO see GH 18463
387387
# tm.assert_frame_equal does not differentiate
388-
if any_string_dtype in object_pyarrow_numpy:
388+
if is_object_or_nan_string_dtype(any_string_dtype):
389389
assert all(np.isnan(x) for x in result.iloc[1])
390390
else:
391391
assert all(x is pd.NA for x in result.iloc[1])

pandas/tests/strings/test_strings.py

+23 -9
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
)
1515
import pandas._testing as tm
1616
from pandas.core.strings.accessor import StringMethods
17-
from pandas.tests.strings import object_pyarrow_numpy
17+
from pandas.tests.strings import is_object_or_nan_string_dtype
1818

1919

2020
@pytest.mark.parametrize("pattern", [0, True, Series(["foo", "bar"])])
@@ -41,7 +41,9 @@ def test_iter_raises():
4141
def test_count(any_string_dtype):
4242
ser = Series(["foo", "foofoo", np.nan, "foooofooofommmfoo"], dtype=any_string_dtype)
4343
result = ser.str.count("f[o]+")
44-
expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64"
44+
expected_dtype = (
45+
np.float64 if is_object_or_nan_string_dtype(any_string_dtype) else "Int64"
46+
)
4547
expected = Series([1, 2, np.nan, 4], dtype=expected_dtype)
4648
tm.assert_series_equal(result, expected)
4749

@@ -93,7 +95,7 @@ def test_repeat_with_null(any_string_dtype, arg, repeat):
9395

9496
def test_empty_str_methods(any_string_dtype):
9597
empty_str = empty = Series(dtype=any_string_dtype)
96-
if any_string_dtype in object_pyarrow_numpy:
98+
if is_object_or_nan_string_dtype(any_string_dtype):
9799
empty_int = Series(dtype="int64")
98100
empty_bool = Series(dtype=bool)
99101
else:
@@ -207,7 +209,9 @@ def test_ismethods(method, expected, any_string_dtype):
207209
ser = Series(
208210
["A", "b", "Xy", "4", "3A", "", "TT", "55", "-", " "], dtype=any_string_dtype
209211
)
210-
expected_dtype = "bool" if any_string_dtype in object_pyarrow_numpy else "boolean"
212+
expected_dtype = (
213+
"bool" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
214+
)
211215
expected = Series(expected, dtype=expected_dtype)
212216
result = getattr(ser.str, method)()
213217
tm.assert_series_equal(result, expected)
@@ -233,7 +237,9 @@ def test_isnumeric_unicode(method, expected, any_string_dtype):
233237
["A", "3", "¼", "★", "፸", "3", "four"], # noqa: RUF001
234238
dtype=any_string_dtype,
235239
)
236-
expected_dtype = "bool" if any_string_dtype in object_pyarrow_numpy else "boolean"
240+
expected_dtype = (
241+
"bool" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
242+
)
237243
expected = Series(expected, dtype=expected_dtype)
238244
result = getattr(ser.str, method)()
239245
tm.assert_series_equal(result, expected)
@@ -253,7 +259,9 @@ def test_isnumeric_unicode(method, expected, any_string_dtype):
253259
def test_isnumeric_unicode_missing(method, expected, any_string_dtype):
254260
values = ["A", np.nan, "¼", "★", np.nan, "3", "four"] # noqa: RUF001
255261
ser = Series(values, dtype=any_string_dtype)
256-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
262+
expected_dtype = (
263+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
264+
)
257265
expected = Series(expected, dtype=expected_dtype)
258266
result = getattr(ser.str, method)()
259267
tm.assert_series_equal(result, expected)
@@ -284,7 +292,9 @@ def test_len(any_string_dtype):
284292
dtype=any_string_dtype,
285293
)
286294
result = ser.str.len()
287-
expected_dtype = "float64" if any_string_dtype in object_pyarrow_numpy else "Int64"
295+
expected_dtype = (
296+
"float64" if is_object_or_nan_string_dtype(any_string_dtype) else "Int64"
297+
)
288298
expected = Series([3, 4, 6, np.nan, 8, 4, 1], dtype=expected_dtype)
289299
tm.assert_series_equal(result, expected)
290300

@@ -313,7 +323,9 @@ def test_index(method, sub, start, end, index_or_series, any_string_dtype, expec
313323
obj = index_or_series(
314324
["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"], dtype=any_string_dtype
315325
)
316-
expected_dtype = np.int64 if any_string_dtype in object_pyarrow_numpy else "Int64"
326+
expected_dtype = (
327+
np.int64 if is_object_or_nan_string_dtype(any_string_dtype) else "Int64"
328+
)
317329
expected = index_or_series(expected, dtype=expected_dtype)
318330

319331
result = getattr(obj.str, method)(sub, start, end)
@@ -354,7 +366,9 @@ def test_index_wrong_type_raises(index_or_series, any_string_dtype, method):
354366
)
355367
def test_index_missing(any_string_dtype, method, exp):
356368
ser = Series(["abcb", "ab", "bcbe", np.nan], dtype=any_string_dtype)
357-
expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64"
369+
expected_dtype = (
370+
np.float64 if is_object_or_nan_string_dtype(any_string_dtype) else "Int64"
371+
)
358372

359373
result = getattr(ser.str, method)("b")
360374
expected = Series(exp + [np.nan], dtype=expected_dtype)

0 commit comments

Comments (0)