Skip to content

Commit 8e3963b

Browse files
phoflmroeschke
authored andcommitted
Use NaN as na_value for new pyarrow_numpy StringDtype (pandas-dev#54585)
1 parent a47a4a8 commit 8e3963b

File tree

4 files changed

+45
-24
lines changed

4 files changed

+45
-24
lines changed

pandas/core/arrays/string_.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,14 @@ class StringDtype(StorageExtensionDtype):
101101
# base class "StorageExtensionDtype") with class variable
102102
name: ClassVar[str] = "string" # type: ignore[misc]
103103

104-
#: StringDtype().na_value uses pandas.NA
104+
#: StringDtype().na_value uses pandas.NA except the implementation that
105+
# follows NumPy semantics, which uses nan.
105106
@property
106-
def na_value(self) -> libmissing.NAType:
107-
return libmissing.NA
107+
def na_value(self) -> libmissing.NAType | float: # type: ignore[override]
108+
if self.storage == "pyarrow_numpy":
109+
return np.nan
110+
else:
111+
return libmissing.NA
108112

109113
_metadata = ("storage",)
110114

pandas/tests/arrays/string_/test_string.py

+26-15
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717
)
1818

1919

20+
def na_val(dtype):
21+
if dtype.storage == "pyarrow_numpy":
22+
return np.nan
23+
else:
24+
return pd.NA
25+
26+
2027
@pytest.fixture
2128
def dtype(string_storage):
2229
"""Fixture giving StringDtype from parametrized 'string_storage'"""
@@ -31,26 +38,34 @@ def cls(dtype):
3138

3239
def test_repr(dtype):
3340
df = pd.DataFrame({"A": pd.array(["a", pd.NA, "b"], dtype=dtype)})
34-
expected = " A\n0 a\n1 <NA>\n2 b"
41+
if dtype.storage == "pyarrow_numpy":
42+
expected = " A\n0 a\n1 NaN\n2 b"
43+
else:
44+
expected = " A\n0 a\n1 <NA>\n2 b"
3545
assert repr(df) == expected
3646

37-
expected = "0 a\n1 <NA>\n2 b\nName: A, dtype: string"
47+
if dtype.storage == "pyarrow_numpy":
48+
expected = "0 a\n1 NaN\n2 b\nName: A, dtype: string"
49+
else:
50+
expected = "0 a\n1 <NA>\n2 b\nName: A, dtype: string"
3851
assert repr(df.A) == expected
3952

4053
if dtype.storage == "pyarrow":
4154
arr_name = "ArrowStringArray"
55+
expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
4256
elif dtype.storage == "pyarrow_numpy":
4357
arr_name = "ArrowStringArrayNumpySemantics"
58+
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: string"
4459
else:
4560
arr_name = "StringArray"
46-
expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
61+
expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
4762
assert repr(df.A.array) == expected
4863

4964

5065
def test_none_to_nan(cls):
5166
a = cls._from_sequence(["a", None, "b"])
5267
assert a[1] is not None
53-
assert a[1] is pd.NA
68+
assert a[1] is na_val(a.dtype)
5469

5570

5671
def test_setitem_validates(cls):
@@ -213,13 +228,9 @@ def test_comparison_methods_scalar(comparison_op, dtype):
213228
other = "a"
214229
result = getattr(a, op_name)(other)
215230
if dtype.storage == "pyarrow_numpy":
216-
expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object)
217-
expected = (
218-
pd.array(expected, dtype="boolean")
219-
.to_numpy(na_value=False)
220-
.astype(np.bool_)
221-
)
222-
tm.assert_numpy_array_equal(result, expected)
231+
expected = np.array([getattr(item, op_name)(other) for item in a])
232+
expected[1] = False
233+
tm.assert_numpy_array_equal(result, expected.astype(np.bool_))
223234
else:
224235
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
225236
expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object)
@@ -415,7 +426,7 @@ def test_min_max(method, skipna, dtype, request):
415426
expected = "a" if method == "min" else "c"
416427
assert result == expected
417428
else:
418-
assert result is pd.NA
429+
assert result is na_val(arr.dtype)
419430

420431

421432
@pytest.mark.parametrize("method", ["min", "max"])
@@ -483,7 +494,7 @@ def test_arrow_roundtrip(dtype, string_storage2):
483494
expected = df.astype(f"string[{string_storage2}]")
484495
tm.assert_frame_equal(result, expected)
485496
# ensure the missing value is represented by NA and not np.nan or None
486-
assert result.loc[2, "a"] is pd.NA
497+
assert result.loc[2, "a"] is na_val(result["a"].dtype)
487498

488499

489500
def test_arrow_load_from_zero_chunks(dtype, string_storage2):
@@ -581,7 +592,7 @@ def test_astype_from_float_dtype(float_dtype, dtype):
581592
def test_to_numpy_returns_pdna_default(dtype):
582593
arr = pd.array(["a", pd.NA, "b"], dtype=dtype)
583594
result = np.array(arr)
584-
expected = np.array(["a", pd.NA, "b"], dtype=object)
595+
expected = np.array(["a", na_val(dtype), "b"], dtype=object)
585596
tm.assert_numpy_array_equal(result, expected)
586597

587598

@@ -621,7 +632,7 @@ def test_setitem_scalar_with_mask_validation(dtype):
621632
mask = np.array([False, True, False])
622633

623634
ser[mask] = None
624-
assert ser.array[1] is pd.NA
635+
assert ser.array[1] is na_val(ser.dtype)
625636

626637
# for other non-string we should also raise an error
627638
ser = pd.Series(["a", "b", "c"], dtype=dtype)

pandas/tests/strings/__init__.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Needed for new arrow string dtype
1+
import numpy as np
22

33
import pandas as pd
44

@@ -7,6 +7,9 @@
77

88
def _convert_na_value(ser, expected):
99
if ser.dtype != object:
10-
# GH#18463
11-
expected = expected.fillna(pd.NA)
10+
if ser.dtype.storage == "pyarrow_numpy":
11+
expected = expected.fillna(np.nan)
12+
else:
13+
# GH#18463
14+
expected = expected.fillna(pd.NA)
1215
return expected

pandas/tests/strings/test_split_partition.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
Series,
1313
_testing as tm,
1414
)
15-
from pandas.tests.strings import _convert_na_value
15+
from pandas.tests.strings import (
16+
_convert_na_value,
17+
object_pyarrow_numpy,
18+
)
1619

1720

1821
@pytest.mark.parametrize("method", ["split", "rsplit"])
@@ -113,8 +116,8 @@ def test_split_object_mixed(expand, method):
113116
def test_split_n(any_string_dtype, method, n):
114117
s = Series(["a b", pd.NA, "b c"], dtype=any_string_dtype)
115118
expected = Series([["a", "b"], pd.NA, ["b", "c"]])
116-
117119
result = getattr(s.str, method)(" ", n=n)
120+
expected = _convert_na_value(s, expected)
118121
tm.assert_series_equal(result, expected)
119122

120123

@@ -381,7 +384,7 @@ def test_split_nan_expand(any_string_dtype):
381384
# check that these are actually np.nan/pd.NA and not None
382385
# TODO see GH 18463
383386
# tm.assert_frame_equal does not differentiate
384-
if any_string_dtype == "object":
387+
if any_string_dtype in object_pyarrow_numpy:
385388
assert all(np.isnan(x) for x in result.iloc[1])
386389
else:
387390
assert all(x is pd.NA for x in result.iloc[1])

0 commit comments

Comments
 (0)