Skip to content

Commit 955eb2f

Browse files
TST (string dtype): change any_string_dtype fixture to use actual dtype instances (pandas-dev#59345)
* TST (string dtype): change any_string_dtype fixture to use actual dtype instances * avoid pyarrow import error during test collection * fix dtype equality in case pyarrow is not installed * keep using mode.string_storage as default for NA variant + more xfails * fix test_series_string_inference_storage_definition * remove no longer necessary xfails --------- Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 8cebd65 commit 955eb2f

File tree

11 files changed

+115
-47
lines changed

11 files changed

+115
-47
lines changed

pandas/conftest.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -1306,20 +1306,33 @@ def object_dtype(request):
13061306

13071307
@pytest.fixture(
13081308
params=[
1309-
"object",
1310-
"string[python]",
1311-
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
1312-
pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")),
1313-
]
1309+
np.dtype("object"),
1310+
("python", pd.NA),
1311+
pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
1312+
pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
1313+
],
1314+
ids=[
1315+
"string=object",
1316+
"string=string[python]",
1317+
"string=string[pyarrow]",
1318+
"string=str[pyarrow]",
1319+
],
13141320
)
13151321
def any_string_dtype(request):
13161322
"""
13171323
Parametrized fixture for string dtypes.
13181324
* 'object'
1319-
* 'string[python]'
1320-
* 'string[pyarrow]'
1325+
* 'string[python]' (NA variant)
1326+
* 'string[pyarrow]' (NA variant)
1327+
* 'str' (NaN variant, with pyarrow)
13211328
"""
1322-
return request.param
1329+
if isinstance(request.param, np.dtype):
1330+
return request.param
1331+
else:
1332+
# need to instantiate the StringDtype here instead of in the params
1333+
# to avoid importing pyarrow during test collection
1334+
storage, na_value = request.param
1335+
return pd.StringDtype(storage, na_value)
13231336

13241337

13251338
@pytest.fixture(params=tm.DATETIME64_DTYPES)

pandas/core/arrays/string_.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __init__(
124124
) -> None:
125125
# infer defaults
126126
if storage is None:
127-
if using_string_dtype():
127+
if using_string_dtype() and na_value is not libmissing.NA:
128128
storage = "pyarrow"
129129
else:
130130
storage = get_option("mode.string_storage")
@@ -162,7 +162,9 @@ def __eq__(self, other: object) -> bool:
162162
return True
163163
try:
164164
other = self.construct_from_string(other)
165-
except TypeError:
165+
except (TypeError, ImportError):
166+
# TypeError if `other` is not a valid string for StringDtype
167+
# ImportError if pyarrow is not installed for "string[pyarrow]"
166168
return False
167169
if isinstance(other, type(self)):
168170
return self.storage == other.storage and self.na_value is other.na_value

pandas/tests/arrays/categorical/test_constructors.py

-1
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,6 @@ def test_interval(self):
743743
tm.assert_numpy_array_equal(cat.codes, expected_codes)
744744
tm.assert_index_equal(cat.categories, idx)
745745

746-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
747746
def test_categorical_extension_array_nullable(self, nulls_fixture):
748747
# GH:
749748
arr = pd.arrays.StringArray._from_sequence(

pandas/tests/copy_view/test_astype.py

-2
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def test_astype_numpy_to_ea():
9898
assert np.shares_memory(get_array(ser), get_array(result))
9999

100100

101-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
102101
@pytest.mark.parametrize(
103102
"dtype, new_dtype", [("object", "string"), ("string", "object")]
104103
)
@@ -116,7 +115,6 @@ def test_astype_string_and_object(using_copy_on_write, dtype, new_dtype):
116115
tm.assert_frame_equal(df, df_orig)
117116

118117

119-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
120118
@pytest.mark.parametrize(
121119
"dtype, new_dtype", [("object", "string"), ("string", "object")]
122120
)

pandas/tests/dtypes/test_common.py

-3
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
from pandas._config import using_string_dtype
7-
86
import pandas.util._test_decorators as td
97

108
from pandas.core.dtypes.astype import astype_array
@@ -130,7 +128,6 @@ def test_dtype_equal(name1, dtype1, name2, dtype2):
130128
assert not com.is_dtype_equal(dtype1, dtype2)
131129

132130

133-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
134131
@pytest.mark.parametrize("name,dtype", list(dtypes.items()), ids=lambda x: str(x))
135132
def test_pyarrow_string_import_error(name, dtype):
136133
# GH-44276

pandas/tests/io/parser/test_index_col.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import numpy as np
99
import pytest
1010

11+
from pandas._config import using_string_dtype
12+
1113
from pandas import (
1214
DataFrame,
1315
Index,
@@ -342,6 +344,7 @@ def test_infer_types_boolean_sum(all_parsers):
342344
tm.assert_frame_equal(result, expected, check_index_type=False)
343345

344346

347+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
345348
@pytest.mark.parametrize("dtype, val", [(object, "01"), ("int64", 1)])
346349
def test_specify_dtype_for_index_col(all_parsers, dtype, val, request):
347350
# GH#9435

pandas/tests/series/test_constructors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2138,7 +2138,7 @@ def test_series_string_inference_storage_definition(self):
21382138
# returning the NA string dtype, so expected is changed from
21392139
# "string[pyarrow_numpy]" to "string[pyarrow]"
21402140
pytest.importorskip("pyarrow")
2141-
expected = Series(["a", "b"], dtype="string[pyarrow]")
2141+
expected = Series(["a", "b"], dtype="string[python]")
21422142
with pd.option_context("future.infer_string", True):
21432143
result = Series(["a", "b"], dtype="string")
21442144
tm.assert_series_equal(result, expected)

pandas/tests/strings/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22

33
import pandas as pd
44

5-
object_pyarrow_numpy = ("object", "string[pyarrow_numpy]")
5+
6+
def is_object_or_nan_string_dtype(dtype):
7+
"""
8+
Check if string-like dtype is following NaN semantics, i.e. is object
9+
dtype or a NaN-variant of the StringDtype.
10+
"""
11+
return (isinstance(dtype, np.dtype) and dtype == "object") or (
12+
dtype.na_value is np.nan
13+
)
614

715

816
def _convert_na_value(ser, expected):

pandas/tests/strings/test_find_replace.py

+52-18
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515
from pandas.tests.strings import (
1616
_convert_na_value,
17-
object_pyarrow_numpy,
17+
is_object_or_nan_string_dtype,
1818
)
1919

2020
# --------------------------------------------------------------------------------------
@@ -34,7 +34,9 @@ def test_contains(any_string_dtype):
3434
pat = "mmm[_]+"
3535

3636
result = values.str.contains(pat)
37-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
37+
expected_dtype = (
38+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
39+
)
3840
expected = Series(
3941
np.array([False, np.nan, True, True, False], dtype=np.object_),
4042
dtype=expected_dtype,
@@ -53,7 +55,9 @@ def test_contains(any_string_dtype):
5355
dtype=any_string_dtype,
5456
)
5557
result = values.str.contains(pat)
56-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
58+
expected_dtype = (
59+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
60+
)
5761
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
5862
tm.assert_series_equal(result, expected)
5963

@@ -80,14 +84,18 @@ def test_contains(any_string_dtype):
8084
pat = "mmm[_]+"
8185

8286
result = values.str.contains(pat)
83-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
87+
expected_dtype = (
88+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
89+
)
8490
expected = Series(
8591
np.array([False, np.nan, True, True], dtype=np.object_), dtype=expected_dtype
8692
)
8793
tm.assert_series_equal(result, expected)
8894

8995
result = values.str.contains(pat, na=False)
90-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
96+
expected_dtype = (
97+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
98+
)
9199
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
92100
tm.assert_series_equal(result, expected)
93101

@@ -172,7 +180,9 @@ def test_contains_moar(any_string_dtype):
172180
)
173181

174182
result = s.str.contains("a")
175-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
183+
expected_dtype = (
184+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
185+
)
176186
expected = Series(
177187
[False, False, False, True, True, False, np.nan, False, False, True],
178188
dtype=expected_dtype,
@@ -213,7 +223,9 @@ def test_contains_nan(any_string_dtype):
213223
s = Series([np.nan, np.nan, np.nan], dtype=any_string_dtype)
214224

215225
result = s.str.contains("foo", na=False)
216-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
226+
expected_dtype = (
227+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
228+
)
217229
expected = Series([False, False, False], dtype=expected_dtype)
218230
tm.assert_series_equal(result, expected)
219231

@@ -231,7 +243,9 @@ def test_contains_nan(any_string_dtype):
231243
tm.assert_series_equal(result, expected)
232244

233245
result = s.str.contains("foo")
234-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
246+
expected_dtype = (
247+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
248+
)
235249
expected = Series([np.nan, np.nan, np.nan], dtype=expected_dtype)
236250
tm.assert_series_equal(result, expected)
237251

@@ -641,7 +655,9 @@ def test_replace_regex_single_character(regex, any_string_dtype):
641655

642656
def test_match(any_string_dtype):
643657
# New match behavior introduced in 0.13
644-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
658+
expected_dtype = (
659+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
660+
)
645661

646662
values = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype)
647663
result = values.str.match(".*(BAD[_]+).*(BAD)")
@@ -696,20 +712,26 @@ def test_match_na_kwarg(any_string_dtype):
696712
s = Series(["a", "b", np.nan], dtype=any_string_dtype)
697713

698714
result = s.str.match("a", na=False)
699-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
715+
expected_dtype = (
716+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
717+
)
700718
expected = Series([True, False, False], dtype=expected_dtype)
701719
tm.assert_series_equal(result, expected)
702720

703721
result = s.str.match("a")
704-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
722+
expected_dtype = (
723+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
724+
)
705725
expected = Series([True, False, np.nan], dtype=expected_dtype)
706726
tm.assert_series_equal(result, expected)
707727

708728

709729
def test_match_case_kwarg(any_string_dtype):
710730
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
711731
result = values.str.match("ab", case=False)
712-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
732+
expected_dtype = (
733+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
734+
)
713735
expected = Series([True, True, True, True], dtype=expected_dtype)
714736
tm.assert_series_equal(result, expected)
715737

@@ -725,7 +747,9 @@ def test_fullmatch(any_string_dtype):
725747
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
726748
)
727749
result = ser.str.fullmatch(".*BAD[_]+.*BAD")
728-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
750+
expected_dtype = (
751+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
752+
)
729753
expected = Series([True, False, np.nan, False], dtype=expected_dtype)
730754
tm.assert_series_equal(result, expected)
731755

@@ -734,7 +758,9 @@ def test_fullmatch_dollar_literal(any_string_dtype):
734758
# GH 56652
735759
ser = Series(["foo", "foo$foo", np.nan, "foo$"], dtype=any_string_dtype)
736760
result = ser.str.fullmatch("foo\\$")
737-
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
761+
expected_dtype = (
762+
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
763+
)
738764
expected = Series([False, False, np.nan, True], dtype=expected_dtype)
739765
tm.assert_series_equal(result, expected)
740766

@@ -744,14 +770,18 @@ def test_fullmatch_na_kwarg(any_string_dtype):
744770
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
745771
)
746772
result = ser.str.fullmatch(".*BAD[_]+.*BAD", na=False)
747-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
773+
expected_dtype = (
774+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
775+
)
748776
expected = Series([True, False, False, False], dtype=expected_dtype)
749777
tm.assert_series_equal(result, expected)
750778

751779

752780
def test_fullmatch_case_kwarg(any_string_dtype):
753781
ser = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
754-
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
782+
expected_dtype = (
783+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
784+
)
755785

756786
expected = Series([True, False, False, False], dtype=expected_dtype)
757787

@@ -823,7 +853,9 @@ def test_find(any_string_dtype):
823853
ser = Series(
824854
["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXXX"], dtype=any_string_dtype
825855
)
826-
expected_dtype = np.int64 if any_string_dtype in object_pyarrow_numpy else "Int64"
856+
expected_dtype = (
857+
np.int64 if is_object_or_nan_string_dtype(any_string_dtype) else "Int64"
858+
)
827859

828860
result = ser.str.find("EF")
829861
expected = Series([4, 3, 1, 0, -1], dtype=expected_dtype)
@@ -875,7 +907,9 @@ def test_find_nan(any_string_dtype):
875907
ser = Series(
876908
["ABCDEFG", np.nan, "DEFGHIJEF", np.nan, "XXXX"], dtype=any_string_dtype
877909
)
878-
expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64"
910+
expected_dtype = (
911+
np.float64 if is_object_or_nan_string_dtype(any_string_dtype) else "Int64"
912+
)
879913

880914
result = ser.str.find("EF")
881915
expected = Series([4, np.nan, 1, np.nan, -1], dtype=expected_dtype)

pandas/tests/strings/test_split_partition.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515
from pandas.tests.strings import (
1616
_convert_na_value,
17-
object_pyarrow_numpy,
17+
is_object_or_nan_string_dtype,
1818
)
1919

2020

@@ -384,7 +384,7 @@ def test_split_nan_expand(any_string_dtype):
384384
# check that these are actually np.nan/pd.NA and not None
385385
# TODO see GH 18463
386386
# tm.assert_frame_equal does not differentiate
387-
if any_string_dtype in object_pyarrow_numpy:
387+
if is_object_or_nan_string_dtype(any_string_dtype):
388388
assert all(np.isnan(x) for x in result.iloc[1])
389389
else:
390390
assert all(x is pd.NA for x in result.iloc[1])

0 commit comments

Comments
 (0)