Skip to content

Commit 14a6c9a

Browse files
authored
REF: Refactor conversion of na value (#54586)
1 parent 7915acb commit 14a6c9a

File tree

3 files changed

+25
-31
lines changed

3 files changed

+25
-31
lines changed

pandas/tests/strings/__init__.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,12 @@
11
# Needed for new arrow string dtype
2+
3+
import pandas as pd
4+
25
# Tuple of dtype names; at present it contains only "object".
# NOTE(review): imported by test_find_replace.py; presumably tests compare the
# parametrized string dtype against this group -- confirm at the call sites.
object_pyarrow_numpy = ("object",)
6+
7+
8+
def _convert_na_value(ser, expected):
9+
if ser.dtype != object:
10+
# GH#18463
11+
expected = expected.fillna(pd.NA)
12+
return expected

pandas/tests/strings/test_find_replace.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
Series,
1212
_testing as tm,
1313
)
14-
from pandas.tests.strings import object_pyarrow_numpy
14+
from pandas.tests.strings import (
15+
_convert_na_value,
16+
object_pyarrow_numpy,
17+
)
1518

1619
# --------------------------------------------------------------------------------------
1720
# str.contains
@@ -758,9 +761,7 @@ def test_findall(any_string_dtype):
758761
ser = Series(["fooBAD__barBAD", np.nan, "foo", "BAD"], dtype=any_string_dtype)
759762
result = ser.str.findall("BAD[_]*")
760763
expected = Series([["BAD__", "BAD"], np.nan, [], ["BAD"]])
761-
if ser.dtype != object:
762-
# GH#18463
763-
expected = expected.fillna(pd.NA)
764+
expected = _convert_na_value(ser, expected)
764765
tm.assert_series_equal(result, expected)
765766

766767

pandas/tests/strings/test_split_partition.py

+10-27
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Series,
1313
_testing as tm,
1414
)
15+
from pandas.tests.strings import _convert_na_value
1516

1617

1718
@pytest.mark.parametrize("method", ["split", "rsplit"])
@@ -20,9 +21,7 @@ def test_split(any_string_dtype, method):
2021

2122
result = getattr(values.str, method)("_")
2223
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
23-
if values.dtype != object:
24-
# GH#18463
25-
exp = exp.fillna(pd.NA)
24+
exp = _convert_na_value(values, exp)
2625
tm.assert_series_equal(result, exp)
2726

2827

@@ -32,9 +31,7 @@ def test_split_more_than_one_char(any_string_dtype, method):
3231
values = Series(["a__b__c", "c__d__e", np.nan, "f__g__h"], dtype=any_string_dtype)
3332
result = getattr(values.str, method)("__")
3433
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
35-
if values.dtype != object:
36-
# GH#18463
37-
exp = exp.fillna(pd.NA)
34+
exp = _convert_na_value(values, exp)
3835
tm.assert_series_equal(result, exp)
3936

4037
result = getattr(values.str, method)("__", expand=False)
@@ -46,9 +43,7 @@ def test_split_more_regex_split(any_string_dtype):
4643
values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype)
4744
result = values.str.split("[,_]")
4845
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
49-
if values.dtype != object:
50-
# GH#18463
51-
exp = exp.fillna(pd.NA)
46+
exp = _convert_na_value(values, exp)
5247
tm.assert_series_equal(result, exp)
5348

5449

@@ -128,9 +123,7 @@ def test_rsplit(any_string_dtype):
128123
values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype)
129124
result = values.str.rsplit("[,_]")
130125
exp = Series([["a,b_c"], ["c_d,e"], np.nan, ["f,g,h"]])
131-
if values.dtype != object:
132-
# GH#18463
133-
exp = exp.fillna(pd.NA)
126+
exp = _convert_na_value(values, exp)
134127
tm.assert_series_equal(result, exp)
135128

136129

@@ -139,9 +132,7 @@ def test_rsplit_max_number(any_string_dtype):
139132
values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"], dtype=any_string_dtype)
140133
result = values.str.rsplit("_", n=1)
141134
exp = Series([["a_b", "c"], ["c_d", "e"], np.nan, ["f_g", "h"]])
142-
if values.dtype != object:
143-
# GH#18463
144-
exp = exp.fillna(pd.NA)
135+
exp = _convert_na_value(values, exp)
145136
tm.assert_series_equal(result, exp)
146137

147138

@@ -455,9 +446,7 @@ def test_partition_series_more_than_one_char(method, exp, any_string_dtype):
455446
s = Series(["a__b__c", "c__d__e", np.nan, "f__g__h", None], dtype=any_string_dtype)
456447
result = getattr(s.str, method)("__", expand=False)
457448
expected = Series(exp)
458-
if s.dtype != object:
459-
# GH#18463
460-
expected = expected.fillna(pd.NA)
449+
expected = _convert_na_value(s, expected)
461450
tm.assert_series_equal(result, expected)
462451

463452

@@ -480,9 +469,7 @@ def test_partition_series_none(any_string_dtype, method, exp):
480469
s = Series(["a b c", "c d e", np.nan, "f g h", None], dtype=any_string_dtype)
481470
result = getattr(s.str, method)(expand=False)
482471
expected = Series(exp)
483-
if s.dtype != object:
484-
# GH#18463
485-
expected = expected.fillna(pd.NA)
472+
expected = _convert_na_value(s, expected)
486473
tm.assert_series_equal(result, expected)
487474

488475

@@ -505,9 +492,7 @@ def test_partition_series_not_split(any_string_dtype, method, exp):
505492
s = Series(["abc", "cde", np.nan, "fgh", None], dtype=any_string_dtype)
506493
result = getattr(s.str, method)("_", expand=False)
507494
expected = Series(exp)
508-
if s.dtype != object:
509-
# GH#18463
510-
expected = expected.fillna(pd.NA)
495+
expected = _convert_na_value(s, expected)
511496
tm.assert_series_equal(result, expected)
512497

513498

@@ -531,9 +516,7 @@ def test_partition_series_unicode(any_string_dtype, method, exp):
531516

532517
result = getattr(s.str, method)("_", expand=False)
533518
expected = Series(exp)
534-
if s.dtype != object:
535-
# GH#18463
536-
expected = expected.fillna(pd.NA)
519+
expected = _convert_na_value(s, expected)
537520
tm.assert_series_equal(result, expected)
538521

539522

0 commit comments

Comments
 (0)