Skip to content

Commit 89a7c34

Browse files
phoflmroeschke
authored andcommitted
REF: Move checks to object into a variable (pandas-dev#54536)
1 parent 28e0366 commit 89a7c34

File tree

3 files changed

+27
-23
lines changed

3 files changed

+27
-23
lines changed

pandas/tests/strings/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Needed for new arrow string dtype
2+
object_pyarrow_numpy = ("object",)

pandas/tests/strings/test_find_replace.py

+16-15
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Series,
1212
_testing as tm,
1313
)
14+
from pandas.tests.strings import object_pyarrow_numpy
1415

1516
# --------------------------------------------------------------------------------------
1617
# str.contains
@@ -25,7 +26,7 @@ def test_contains(any_string_dtype):
2526
pat = "mmm[_]+"
2627

2728
result = values.str.contains(pat)
28-
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
29+
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
2930
expected = Series(
3031
np.array([False, np.nan, True, True, False], dtype=np.object_),
3132
dtype=expected_dtype,
@@ -44,7 +45,7 @@ def test_contains(any_string_dtype):
4445
dtype=any_string_dtype,
4546
)
4647
result = values.str.contains(pat)
47-
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
48+
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
4849
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
4950
tm.assert_series_equal(result, expected)
5051

@@ -71,14 +72,14 @@ def test_contains(any_string_dtype):
7172
pat = "mmm[_]+"
7273

7374
result = values.str.contains(pat)
74-
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
75+
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
7576
expected = Series(
7677
np.array([False, np.nan, True, True], dtype=np.object_), dtype=expected_dtype
7778
)
7879
tm.assert_series_equal(result, expected)
7980

8081
result = values.str.contains(pat, na=False)
81-
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
82+
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
8283
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
8384
tm.assert_series_equal(result, expected)
8485

@@ -163,7 +164,7 @@ def test_contains_moar(any_string_dtype):
163164
)
164165

165166
result = s.str.contains("a")
166-
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
167+
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
167168
expected = Series(
168169
[False, False, False, True, True, False, np.nan, False, False, True],
169170
dtype=expected_dtype,
@@ -204,7 +205,7 @@ def test_contains_nan(any_string_dtype):
204205
s = Series([np.nan, np.nan, np.nan], dtype=any_string_dtype)
205206

206207
result = s.str.contains("foo", na=False)
207-
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
208+
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
208209
expected = Series([False, False, False], dtype=expected_dtype)
209210
tm.assert_series_equal(result, expected)
210211

@@ -220,7 +221,7 @@ def test_contains_nan(any_string_dtype):
220221
tm.assert_series_equal(result, expected)
221222

222223
result = s.str.contains("foo")
223-
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
224+
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
224225
expected = Series([np.nan, np.nan, np.nan], dtype=expected_dtype)
225226
tm.assert_series_equal(result, expected)
226227

@@ -648,7 +649,7 @@ def test_replace_regex_single_character(regex, any_string_dtype):
648649

649650
def test_match(any_string_dtype):
650651
# New match behavior introduced in 0.13
651-
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
652+
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
652653

653654
values = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype)
654655
result = values.str.match(".*(BAD[_]+).*(BAD)")
@@ -703,20 +704,20 @@ def test_match_na_kwarg(any_string_dtype):
703704
s = Series(["a", "b", np.nan], dtype=any_string_dtype)
704705

705706
result = s.str.match("a", na=False)
706-
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
707+
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
707708
expected = Series([True, False, False], dtype=expected_dtype)
708709
tm.assert_series_equal(result, expected)
709710

710711
result = s.str.match("a")
711-
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
712+
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
712713
expected = Series([True, False, np.nan], dtype=expected_dtype)
713714
tm.assert_series_equal(result, expected)
714715

715716

716717
def test_match_case_kwarg(any_string_dtype):
717718
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
718719
result = values.str.match("ab", case=False)
719-
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
720+
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
720721
expected = Series([True, True, True, True], dtype=expected_dtype)
721722
tm.assert_series_equal(result, expected)
722723

@@ -732,7 +733,7 @@ def test_fullmatch(any_string_dtype):
732733
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
733734
)
734735
result = ser.str.fullmatch(".*BAD[_]+.*BAD")
735-
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
736+
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
736737
expected = Series([True, False, np.nan, False], dtype=expected_dtype)
737738
tm.assert_series_equal(result, expected)
738739

@@ -742,14 +743,14 @@ def test_fullmatch_na_kwarg(any_string_dtype):
742743
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
743744
)
744745
result = ser.str.fullmatch(".*BAD[_]+.*BAD", na=False)
745-
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
746+
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
746747
expected = Series([True, False, False, False], dtype=expected_dtype)
747748
tm.assert_series_equal(result, expected)
748749

749750

750751
def test_fullmatch_case_kwarg(any_string_dtype):
751752
ser = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
752-
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
753+
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
753754

754755
expected = Series([True, False, False, False], dtype=expected_dtype)
755756

@@ -877,7 +878,7 @@ def test_find_nan(any_string_dtype):
877878
ser = Series(
878879
["ABCDEFG", np.nan, "DEFGHIJEF", np.nan, "XXXX"], dtype=any_string_dtype
879880
)
880-
expected_dtype = np.float64 if any_string_dtype == "object" else "Int64"
881+
expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64"
881882

882883
result = ser.str.find("EF")
883884
expected = Series([4, np.nan, 1, np.nan, -1], dtype=expected_dtype)

pandas/tests/strings/test_strings.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
import pandas._testing as tm
1616
from pandas.core.strings.accessor import StringMethods
17+
from pandas.tests.strings import object_pyarrow_numpy
1718

1819

1920
@pytest.mark.parametrize("pattern", [0, True, Series(["foo", "bar"])])
@@ -40,7 +41,7 @@ def test_iter_raises():
4041
def test_count(any_string_dtype):
4142
ser = Series(["foo", "foofoo", np.nan, "foooofooofommmfoo"], dtype=any_string_dtype)
4243
result = ser.str.count("f[o]+")
43-
expected_dtype = np.float64 if any_string_dtype == "object" else "Int64"
44+
expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64"
4445
expected = Series([1, 2, np.nan, 4], dtype=expected_dtype)
4546
tm.assert_series_equal(result, expected)
4647

@@ -91,7 +92,7 @@ def test_repeat_with_null(any_string_dtype, arg, repeat):
9192

9293
def test_empty_str_methods(any_string_dtype):
9394
empty_str = empty = Series(dtype=any_string_dtype)
94-
if any_string_dtype == "object":
95+
if any_string_dtype in object_pyarrow_numpy:
9596
empty_int = Series(dtype="int64")
9697
empty_bool = Series(dtype=bool)
9798
else:
@@ -205,7 +206,7 @@ def test_ismethods(method, expected, any_string_dtype):
205206
ser = Series(
206207
["A", "b", "Xy", "4", "3A", "", "TT", "55", "-", " "], dtype=any_string_dtype
207208
)
208-
expected_dtype = "bool" if any_string_dtype == "object" else "boolean"
209+
expected_dtype = "bool" if any_string_dtype in object_pyarrow_numpy else "boolean"
209210
expected = Series(expected, dtype=expected_dtype)
210211
result = getattr(ser.str, method)()
211212
tm.assert_series_equal(result, expected)
@@ -230,7 +231,7 @@ def test_isnumeric_unicode(method, expected, any_string_dtype):
230231
ser = Series(
231232
["A", "3", "¼", "★", "፸", "3", "four"], dtype=any_string_dtype # noqa: RUF001
232233
)
233-
expected_dtype = "bool" if any_string_dtype == "object" else "boolean"
234+
expected_dtype = "bool" if any_string_dtype in object_pyarrow_numpy else "boolean"
234235
expected = Series(expected, dtype=expected_dtype)
235236
result = getattr(ser.str, method)()
236237
tm.assert_series_equal(result, expected)
@@ -250,7 +251,7 @@ def test_isnumeric_unicode(method, expected, any_string_dtype):
250251
def test_isnumeric_unicode_missing(method, expected, any_string_dtype):
251252
values = ["A", np.nan, "¼", "★", np.nan, "3", "four"] # noqa: RUF001
252253
ser = Series(values, dtype=any_string_dtype)
253-
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
254+
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
254255
expected = Series(expected, dtype=expected_dtype)
255256
result = getattr(ser.str, method)()
256257
tm.assert_series_equal(result, expected)
@@ -280,7 +281,7 @@ def test_len(any_string_dtype):
280281
dtype=any_string_dtype,
281282
)
282283
result = ser.str.len()
283-
expected_dtype = "float64" if any_string_dtype == "object" else "Int64"
284+
expected_dtype = "float64" if any_string_dtype in object_pyarrow_numpy else "Int64"
284285
expected = Series([3, 4, 6, np.nan, 8, 4, 1], dtype=expected_dtype)
285286
tm.assert_series_equal(result, expected)
286287

@@ -309,7 +310,7 @@ def test_index(method, sub, start, end, index_or_series, any_string_dtype, expec
309310
obj = index_or_series(
310311
["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"], dtype=any_string_dtype
311312
)
312-
expected_dtype = np.int64 if any_string_dtype == "object" else "Int64"
313+
expected_dtype = np.int64 if any_string_dtype in object_pyarrow_numpy else "Int64"
313314
expected = index_or_series(expected, dtype=expected_dtype)
314315

315316
result = getattr(obj.str, method)(sub, start, end)
@@ -350,7 +351,7 @@ def test_index_wrong_type_raises(index_or_series, any_string_dtype, method):
350351
)
351352
def test_index_missing(any_string_dtype, method, exp):
352353
ser = Series(["abcb", "ab", "bcbe", np.nan], dtype=any_string_dtype)
353-
expected_dtype = np.float64 if any_string_dtype == "object" else "Int64"
354+
expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64"
354355

355356
result = getattr(ser.str, method)("b")
356357
expected = Series(exp + [np.nan], dtype=expected_dtype)

0 commit comments

Comments
 (0)