Skip to content

Commit 89c8d7a

Browse files
TST (string dtype): change any_string_dtype fixture to use actual dtype instances (#59345)
* TST (string dtype): change any_string_dtype fixture to use actual dtype instances
* avoid pyarrow import error during test collection
* fix dtype equality in case pyarrow is not installed
* keep using mode.string_storage as default for NA variant + more xfails
* fix test_series_string_inference_storage_definition
* remove no longer necessary xfails

---------

Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 4c39c08 commit 89c8d7a

File tree

12 files changed

+115
-50
lines changed

12 files changed

+115
-50
lines changed

pandas/conftest.py

+21-8
Original file line number | Diff line number | Diff line change
@@ -1354,20 +1354,33 @@ def object_dtype(request):
13541354

13551355
@pytest.fixture(
13561356
params=[
1357-
"object",
1358-
"string[python]",
1359-
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
1360-
pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")),
1361-
]
1357+
np.dtype("object"),
1358+
("python", pd.NA),
1359+
pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
1360+
pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
1361+
],
1362+
ids=[
1363+
"string=object",
1364+
"string=string[python]",
1365+
"string=string[pyarrow]",
1366+
"string=str[pyarrow]",
1367+
],
13621368
)
13631369
def any_string_dtype(request):
13641370
"""
13651371
Parametrized fixture for string dtypes.
13661372
* 'object'
1367-
* 'string[python]'
1368-
* 'string[pyarrow]'
1373+
* 'string[python]' (NA variant)
1374+
* 'string[pyarrow]' (NA variant)
1375+
* 'str' (NaN variant, with pyarrow)
13691376
"""
1370-
return request.param
1377+
if isinstance(request.param, np.dtype):
1378+
return request.param
1379+
else:
1380+
# need to instantiate the StringDtype here instead of in the params
1381+
# to avoid importing pyarrow during test collection
1382+
storage, na_value = request.param
1383+
return pd.StringDtype(storage, na_value)
13711384

13721385

13731386
@pytest.fixture(params=tm.DATETIME64_DTYPES)

pandas/core/arrays/string_.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ def __init__(
129129
) -> None:
130130
# infer defaults
131131
if storage is None:
132-
if using_string_dtype():
132+
if using_string_dtype() and na_value is not libmissing.NA:
133133
storage = "pyarrow"
134134
else:
135135
storage = get_option("mode.string_storage")
@@ -167,7 +167,9 @@ def __eq__(self, other: object) -> bool:
167167
return True
168168
try:
169169
other = self.construct_from_string(other)
170-
except TypeError:
170+
except (TypeError, ImportError):
171+
# TypeError if `other` is not a valid string for StringDtype
172+
# ImportError if pyarrow is not installed for "string[pyarrow]"
171173
return False
172174
if isinstance(other, type(self)):
173175
return self.storage == other.storage and self.na_value is other.na_value

pandas/tests/arrays/categorical/test_constructors.py

-1
Original file line number | Diff line number | Diff line change
@@ -735,7 +735,6 @@ def test_interval(self):
735735
tm.assert_numpy_array_equal(cat.codes, expected_codes)
736736
tm.assert_index_equal(cat.categories, idx)
737737

738-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
739738
def test_categorical_extension_array_nullable(self, nulls_fixture):
740739
# GH:
741740
arr = pd.arrays.StringArray._from_sequence(

pandas/tests/copy_view/test_array.py

-3
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
11
import numpy as np
22
import pytest
33

4-
from pandas._config import using_string_dtype
5-
64
from pandas import (
75
DataFrame,
86
Series,
@@ -119,7 +117,6 @@ def test_dataframe_array_ea_dtypes():
119117
assert arr.flags.writeable is False
120118

121119

122-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
123120
def test_dataframe_array_string_dtype():
124121
df = DataFrame({"a": ["a", "b"]}, dtype="string")
125122
arr = np.asarray(df)

pandas/tests/copy_view/test_astype.py

-2
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,6 @@ def test_astype_numpy_to_ea():
8484
assert np.shares_memory(get_array(ser), get_array(result))
8585

8686

87-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
8887
@pytest.mark.parametrize(
8988
"dtype, new_dtype", [("object", "string"), ("string", "object")]
9089
)
@@ -98,7 +97,6 @@ def test_astype_string_and_object(dtype, new_dtype):
9897
tm.assert_frame_equal(df, df_orig)
9998

10099

101-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
102100
@pytest.mark.parametrize(
103101
"dtype, new_dtype", [("object", "string"), ("string", "object")]
104102
)

pandas/tests/dtypes/test_common.py

-3
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
from pandas._config import using_string_dtype
7-
86
import pandas.util._test_decorators as td
97

108
from pandas.core.dtypes.astype import astype_array
@@ -130,7 +128,6 @@ def test_dtype_equal(name1, dtype1, name2, dtype2):
130128
assert not com.is_dtype_equal(dtype1, dtype2)
131129

132130

133-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
134131
@pytest.mark.parametrize("name,dtype", list(dtypes.items()), ids=lambda x: str(x))
135132
def test_pyarrow_string_import_error(name, dtype):
136133
# GH-44276

pandas/tests/io/parser/test_index_col.py

+3
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,8 @@
99
import numpy as np
1010
import pytest
1111

12+
from pandas._config import using_string_dtype
13+
1214
from pandas import (
1315
DataFrame,
1416
Index,
@@ -343,6 +345,7 @@ def test_infer_types_boolean_sum(all_parsers):
343345
tm.assert_frame_equal(result, expected, check_index_type=False)
344346

345347

348+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
346349
@pytest.mark.parametrize("dtype, val", [(object, "01"), ("int64", 1)])
347350
def test_specify_dtype_for_index_col(all_parsers, dtype, val, request):
348351
# GH#9435

pandas/tests/series/test_constructors.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -2118,7 +2118,7 @@ def test_series_string_inference_storage_definition(self):
21182118
# returning the NA string dtype, so expected is changed from
21192119
# "string[pyarrow_numpy]" to "string[pyarrow]"
21202120
pytest.importorskip("pyarrow")
2121-
expected = Series(["a", "b"], dtype="string[pyarrow]")
2121+
expected = Series(["a", "b"], dtype="string[python]")
21222122
with pd.option_context("future.infer_string", True):
21232123
result = Series(["a", "b"], dtype="string")
21242124
tm.assert_series_equal(result, expected)

pandas/tests/strings/__init__.py

+9-1
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,15 @@
22

33
import pandas as pd
44

5-
object_pyarrow_numpy = ("object", "string[pyarrow_numpy]")
5+
6+
def is_object_or_nan_string_dtype(dtype):
7+
"""
8+
Check if string-like dtype is following NaN semantics, i.e. is object
9+
dtype or a NaN-variant of the StringDtype.
10+
"""
11+
return (isinstance(dtype, np.dtype) and dtype == "object") or (
12+
dtype.na_value is np.nan
13+
)
614

715

816
def _convert_na_value(ser, expected):

pandas/tests/strings/test_find_replace.py

+52-18
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
)
1414
from pandas.tests.strings import (
1515
_convert_na_value,
16-
object_pyarrow_numpy,
16+
is_object_or_nan_string_dtype,
1717
)
1818

1919
# --------------------------------------------------------------------------------------
@@ -33,7 +33,9 @@ def test_contains(any_string_dtype):
3333
pat = "mmm[_]+"
3434

3535
result = values.str.contains(pat)
36-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
36+
expected_dtype = (
37+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
38+
)
3739
expected = Series(
3840
np.array([False, np.nan, True, True, False], dtype=np.object_),
3941
dtype=expected_dtype,
@@ -52,7 +54,9 @@ def test_contains(any_string_dtype):
5254
dtype=any_string_dtype,
5355
)
5456
result = values.str.contains(pat)
55-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
57+
expected_dtype = (
58+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
59+
)
5660
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
5761
tm.assert_series_equal(result, expected)
5862

@@ -79,14 +83,18 @@ def test_contains(any_string_dtype):
7983
pat = "mmm[_]+"
8084

8185
result = values.str.contains(pat)
82-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
86+
expected_dtype = (
87+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
88+
)
8389
expected = Series(
8490
np.array([False, np.nan, True, True], dtype=np.object_), dtype=expected_dtype
8591
)
8692
tm.assert_series_equal(result, expected)
8793

8894
result = values.str.contains(pat, na=False)
89-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
95+
expected_dtype = (
96+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
97+
)
9098
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
9199
tm.assert_series_equal(result, expected)
92100

@@ -171,7 +179,9 @@ def test_contains_moar(any_string_dtype):
171179
)
172180

173181
result = s.str.contains("a")
174-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
182+
expected_dtype = (
183+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
184+
)
175185
expected = Series(
176186
[False, False, False, True, True, False, np.nan, False, False, True],
177187
dtype=expected_dtype,
@@ -212,7 +222,9 @@ def test_contains_nan(any_string_dtype):
212222
s = Series([np.nan, np.nan, np.nan], dtype=any_string_dtype)
213223

214224
result = s.str.contains("foo", na=False)
215-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
225+
expected_dtype = (
226+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
227+
)
216228
expected = Series([False, False, False], dtype=expected_dtype)
217229
tm.assert_series_equal(result, expected)
218230

@@ -230,7 +242,9 @@ def test_contains_nan(any_string_dtype):
230242
tm.assert_series_equal(result, expected)
231243

232244
result = s.str.contains("foo")
233-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
245+
expected_dtype = (
246+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
247+
)
234248
expected = Series([np.nan, np.nan, np.nan], dtype=expected_dtype)
235249
tm.assert_series_equal(result, expected)
236250

@@ -675,7 +689,9 @@ def test_replace_regex_single_character(regex, any_string_dtype):
675689

676690
def test_match(any_string_dtype):
677691
# New match behavior introduced in 0.13
678-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
692+
expected_dtype = (
693+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
694+
)
679695

680696
values = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype)
681697
result = values.str.match(".*(BAD[_]+).*(BAD)")
@@ -730,20 +746,26 @@ def test_match_na_kwarg(any_string_dtype):
730746
s = Series(["a", "b", np.nan], dtype=any_string_dtype)
731747

732748
result = s.str.match("a", na=False)
733-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
749+
expected_dtype = (
750+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
751+
)
734752
expected = Series([True, False, False], dtype=expected_dtype)
735753
tm.assert_series_equal(result, expected)
736754

737755
result = s.str.match("a")
738-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
756+
expected_dtype = (
757+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
758+
)
739759
expected = Series([True, False, np.nan], dtype=expected_dtype)
740760
tm.assert_series_equal(result, expected)
741761

742762

743763
def test_match_case_kwarg(any_string_dtype):
744764
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
745765
result = values.str.match("ab", case=False)
746-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
766+
expected_dtype = (
767+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
768+
)
747769
expected = Series([True, True, True, True], dtype=expected_dtype)
748770
tm.assert_series_equal(result, expected)
749771

@@ -759,7 +781,9 @@ def test_fullmatch(any_string_dtype):
759781
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
760782
)
761783
result = ser.str.fullmatch(".*BAD[_]+.*BAD")
762-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
784+
expected_dtype = (
785+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
786+
)
763787
expected = Series([True, False, np.nan, False], dtype=expected_dtype)
764788
tm.assert_series_equal(result, expected)
765789

@@ -768,7 +792,9 @@ def test_fullmatch_dollar_literal(any_string_dtype):
768792
# GH 56652
769793
ser = Series(["foo", "foo$foo", np.nan, "foo$"], dtype=any_string_dtype)
770794
result = ser.str.fullmatch("foo\\$")
771-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
795+
expected_dtype = (
796+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
797+
)
772798
expected = Series([False, False, np.nan, True], dtype=expected_dtype)
773799
tm.assert_series_equal(result, expected)
774800

@@ -778,14 +804,18 @@ def test_fullmatch_na_kwarg(any_string_dtype):
778804
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
779805
)
780806
result = ser.str.fullmatch(".*BAD[_]+.*BAD", na=False)
781-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
807+
expected_dtype = (
808+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
809+
)
782810
expected = Series([True, False, False, False], dtype=expected_dtype)
783811
tm.assert_series_equal(result, expected)
784812

785813

786814
def test_fullmatch_case_kwarg(any_string_dtype, performance_warning):
787815
ser = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
788-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
816+
expected_dtype = (
817+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
818+
)
789819

790820
expected = Series([True, False, False, False], dtype=expected_dtype)
791821

@@ -859,7 +889,9 @@ def test_find(any_string_dtype):
859889
ser = Series(
860890
["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXXX"], dtype=any_string_dtype
861891
)
862-
expected_dtype = np.int64 if any_string_dtype in object_pyarrow_numpy else "Int64"
892+
expected_dtype = (
893+
np.int64 if is_object_or_nan_string_dtype(any_string_dtype) else "Int64"
894+
)
863895

864896
result = ser.str.find("EF")
865897
expected = Series([4, 3, 1, 0, -1], dtype=expected_dtype)
@@ -911,7 +943,9 @@ def test_find_nan(any_string_dtype):
911943
ser = Series(
912944
["ABCDEFG", np.nan, "DEFGHIJEF", np.nan, "XXXX"], dtype=any_string_dtype
913945
)
914-
expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64"
946+
expected_dtype = (
947+
np.float64 if is_object_or_nan_string_dtype(any_string_dtype) else "Int64"
948+
)
915949

916950
result = ser.str.find("EF")
917951
expected = Series([4, np.nan, 1, np.nan, -1], dtype=expected_dtype)

pandas/tests/strings/test_split_partition.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
)
1515
from pandas.tests.strings import (
1616
_convert_na_value,
17-
object_pyarrow_numpy,
17+
is_object_or_nan_string_dtype,
1818
)
1919

2020

@@ -385,7 +385,7 @@ def test_split_nan_expand(any_string_dtype):
385385
# check that these are actually np.nan/pd.NA and not None
386386
# TODO see GH 18463
387387
# tm.assert_frame_equal does not differentiate
388-
if any_string_dtype in object_pyarrow_numpy:
388+
if is_object_or_nan_string_dtype(any_string_dtype):
389389
assert all(np.isnan(x) for x in result.iloc[1])
390390
else:
391391
assert all(x is pd.NA for x in result.iloc[1])

0 commit comments

Comments
 (0)