Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit f25a09e

Browse files
authored Jul 29, 2024
String dtype: rename the storage options and add na_value keyword in StringDtype() (#59330)
* rename storage option and add na_value keyword
* update init
* fix propagating na_value to Array class + fix some tests
* fix more tests
* disallow pyarrow_numpy as option + fix more cases of checking storage to be pyarrow_numpy
* restore pyarrow_numpy as option for now
* linting
* try fix typing
* try fix typing
* fix dtype equality to take into account the NaN vs NA
* fix pickling of dtype
* fix test_convert_dtypes
* update expected result for dtype='string'
* suppress typing error with _metadata attribute
1 parent 56ea76a commit f25a09e

File tree

20 files changed

+176
-110
lines changed

20 files changed

+176
-110
lines changed
 

‎pandas/_libs/lib.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2702,7 +2702,7 @@ def maybe_convert_objects(ndarray[object] objects,
27022702
if using_string_dtype() and is_string_array(objects, skipna=True):
27032703
from pandas.core.arrays.string_ import StringDtype
27042704

2705-
dtype = StringDtype(storage="pyarrow_numpy")
2705+
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
27062706
return dtype.construct_array_type()._from_sequence(objects, dtype=dtype)
27072707

27082708
elif convert_to_nullable_dtype and is_string_array(objects, skipna=True):

‎pandas/_testing/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,14 +509,14 @@ def shares_memory(left, right) -> bool:
509509
if (
510510
isinstance(left, ExtensionArray)
511511
and is_string_dtype(left.dtype)
512-
and left.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined]
512+
and left.dtype.storage == "pyarrow" # type: ignore[attr-defined]
513513
):
514514
# https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669
515515
left = cast("ArrowExtensionArray", left)
516516
if (
517517
isinstance(right, ExtensionArray)
518518
and is_string_dtype(right.dtype)
519-
and right.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined]
519+
and right.dtype.storage == "pyarrow" # type: ignore[attr-defined]
520520
):
521521
right = cast("ArrowExtensionArray", right)
522522
left_pa_data = left._pa_array

‎pandas/core/arrays/arrow/array.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -575,10 +575,8 @@ def __getitem__(self, item: PositionalIndexer):
575575
if isinstance(item, np.ndarray):
576576
if not len(item):
577577
# Removable once we migrate StringDtype[pyarrow] to ArrowDtype[string]
578-
if self._dtype.name == "string" and self._dtype.storage in (
579-
"pyarrow",
580-
"pyarrow_numpy",
581-
):
578+
if self._dtype.name == "string" and self._dtype.storage == "pyarrow":
579+
# TODO(infer_string) should this be large_string?
582580
pa_dtype = pa.string()
583581
else:
584582
pa_dtype = self._dtype.pyarrow_dtype

‎pandas/core/arrays/string_.py

Lines changed: 68 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
import numpy as np
1111

12-
from pandas._config import get_option
12+
from pandas._config import (
13+
get_option,
14+
using_string_dtype,
15+
)
1316

1417
from pandas._libs import (
1518
lib,
@@ -81,8 +84,10 @@ class StringDtype(StorageExtensionDtype):
8184
8285
Parameters
8386
----------
84-
storage : {"python", "pyarrow", "pyarrow_numpy"}, optional
87+
storage : {"python", "pyarrow"}, optional
8588
If not given, the value of ``pd.options.mode.string_storage``.
89+
na_value : {np.nan, pd.NA}, default pd.NA
90+
Whether the dtype follows NaN or NA missing value semantics.
8691
8792
Attributes
8893
----------
@@ -113,30 +118,67 @@ class StringDtype(StorageExtensionDtype):
113118
# follows NumPy semantics, which uses nan.
114119
@property
115120
def na_value(self) -> libmissing.NAType | float: # type: ignore[override]
116-
if self.storage == "pyarrow_numpy":
117-
return np.nan
118-
else:
119-
return libmissing.NA
121+
return self._na_value
120122

121-
_metadata = ("storage",)
123+
_metadata = ("storage", "_na_value") # type: ignore[assignment]
122124

123-
def __init__(self, storage=None) -> None:
125+
def __init__(
126+
self,
127+
storage: str | None = None,
128+
na_value: libmissing.NAType | float = libmissing.NA,
129+
) -> None:
130+
# infer defaults
124131
if storage is None:
125-
infer_string = get_option("future.infer_string")
126-
if infer_string:
127-
storage = "pyarrow_numpy"
132+
if using_string_dtype():
133+
storage = "pyarrow"
128134
else:
129135
storage = get_option("mode.string_storage")
130-
if storage not in {"python", "pyarrow", "pyarrow_numpy"}:
136+
137+
if storage == "pyarrow_numpy":
138+
# TODO raise a deprecation warning
139+
storage = "pyarrow"
140+
na_value = np.nan
141+
142+
# validate options
143+
if storage not in {"python", "pyarrow"}:
131144
raise ValueError(
132-
f"Storage must be 'python', 'pyarrow' or 'pyarrow_numpy'. "
133-
f"Got {storage} instead."
145+
f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
134146
)
135-
if storage in ("pyarrow", "pyarrow_numpy") and pa_version_under10p1:
147+
if storage == "pyarrow" and pa_version_under10p1:
136148
raise ImportError(
137149
"pyarrow>=10.0.1 is required for PyArrow backed StringArray."
138150
)
151+
152+
if isinstance(na_value, float) and np.isnan(na_value):
153+
# when passed a NaN value, always set to np.nan to ensure we use
154+
# a consistent NaN value (and we can use `dtype.na_value is np.nan`)
155+
na_value = np.nan
156+
elif na_value is not libmissing.NA:
157+
raise ValueError("'na_value' must be np.nan or pd.NA, got {na_value}")
158+
139159
self.storage = storage
160+
self._na_value = na_value
161+
162+
def __eq__(self, other: object) -> bool:
163+
# we need to override the base class __eq__ because na_value (NA or NaN)
164+
# cannot be checked with normal `==`
165+
if isinstance(other, str):
166+
if other == self.name:
167+
return True
168+
try:
169+
other = self.construct_from_string(other)
170+
except TypeError:
171+
return False
172+
if isinstance(other, type(self)):
173+
return self.storage == other.storage and self.na_value is other.na_value
174+
return False
175+
176+
def __hash__(self) -> int:
177+
# need to override __hash__ as well because of overriding __eq__
178+
return super().__hash__()
179+
180+
def __reduce__(self):
181+
return StringDtype, (self.storage, self.na_value)
140182

141183
@property
142184
def type(self) -> type[str]:
@@ -181,6 +223,7 @@ def construct_from_string(cls, string) -> Self:
181223
elif string == "string[pyarrow]":
182224
return cls(storage="pyarrow")
183225
elif string == "string[pyarrow_numpy]":
226+
# TODO deprecate
184227
return cls(storage="pyarrow_numpy")
185228
else:
186229
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
@@ -205,7 +248,7 @@ def construct_array_type( # type: ignore[override]
205248

206249
if self.storage == "python":
207250
return StringArray
208-
elif self.storage == "pyarrow":
251+
elif self.storage == "pyarrow" and self._na_value is libmissing.NA:
209252
return ArrowStringArray
210253
else:
211254
return ArrowStringArrayNumpySemantics
@@ -217,13 +260,17 @@ def __from_arrow__(
217260
Construct StringArray from pyarrow Array/ChunkedArray.
218261
"""
219262
if self.storage == "pyarrow":
220-
from pandas.core.arrays.string_arrow import ArrowStringArray
263+
if self._na_value is libmissing.NA:
264+
from pandas.core.arrays.string_arrow import ArrowStringArray
265+
266+
return ArrowStringArray(array)
267+
else:
268+
from pandas.core.arrays.string_arrow import (
269+
ArrowStringArrayNumpySemantics,
270+
)
221271

222-
return ArrowStringArray(array)
223-
elif self.storage == "pyarrow_numpy":
224-
from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics
272+
return ArrowStringArrayNumpySemantics(array)
225273

226-
return ArrowStringArrayNumpySemantics(array)
227274
else:
228275
import pyarrow
229276

‎pandas/core/arrays/string_arrow.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr
131131
# base class "ArrowExtensionArray" defined the type as "ArrowDtype")
132132
_dtype: StringDtype # type: ignore[assignment]
133133
_storage = "pyarrow"
134+
_na_value: libmissing.NAType | float = libmissing.NA
134135

135136
def __init__(self, values) -> None:
136137
_chk_pyarrow_available()
@@ -140,7 +141,7 @@ def __init__(self, values) -> None:
140141
values = pc.cast(values, pa.large_string())
141142

142143
super().__init__(values)
143-
self._dtype = StringDtype(storage=self._storage)
144+
self._dtype = StringDtype(storage=self._storage, na_value=self._na_value)
144145

145146
if not pa.types.is_large_string(self._pa_array.type) and not (
146147
pa.types.is_dictionary(self._pa_array.type)
@@ -187,10 +188,7 @@ def _from_sequence(
187188

188189
if dtype and not (isinstance(dtype, str) and dtype == "string"):
189190
dtype = pandas_dtype(dtype)
190-
assert isinstance(dtype, StringDtype) and dtype.storage in (
191-
"pyarrow",
192-
"pyarrow_numpy",
193-
)
191+
assert isinstance(dtype, StringDtype) and dtype.storage == "pyarrow"
194192

195193
if isinstance(scalars, BaseMaskedArray):
196194
# avoid costly conversion to object dtype in ensure_string_array and
@@ -597,7 +595,8 @@ def _rank(
597595

598596

599597
class ArrowStringArrayNumpySemantics(ArrowStringArray):
600-
_storage = "pyarrow_numpy"
598+
_storage = "pyarrow"
599+
_na_value = np.nan
601600

602601
@classmethod
603602
def _result_converter(cls, values, na=None):

‎pandas/core/construction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ def sanitize_array(
574574
if isinstance(data, str) and using_string_dtype() and original_dtype is None:
575575
from pandas.core.arrays.string_ import StringDtype
576576

577-
dtype = StringDtype("pyarrow_numpy")
577+
dtype = StringDtype("pyarrow", na_value=np.nan)
578578
data = construct_1d_arraylike_from_scalar(data, len(index), dtype)
579579

580580
return data
@@ -608,7 +608,7 @@ def sanitize_array(
608608
elif data.dtype.kind == "U" and using_string_dtype():
609609
from pandas.core.arrays.string_ import StringDtype
610610

611-
dtype = StringDtype(storage="pyarrow_numpy")
611+
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
612612
subarr = dtype.construct_array_type()._from_sequence(data, dtype=dtype)
613613

614614
if subarr is data and copy:

‎pandas/core/dtypes/cast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@ def infer_dtype_from_scalar(val) -> tuple[DtypeObj, Any]:
801801
if using_string_dtype():
802802
from pandas.core.arrays.string_ import StringDtype
803803

804-
dtype = StringDtype(storage="pyarrow_numpy")
804+
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
805805

806806
elif isinstance(val, (np.datetime64, dt.datetime)):
807807
try:

‎pandas/core/indexes/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5453,9 +5453,10 @@ def equals(self, other: Any) -> bool:
54535453

54545454
if (
54555455
isinstance(self.dtype, StringDtype)
5456-
and self.dtype.storage == "pyarrow_numpy"
5456+
and self.dtype.na_value is np.nan
54575457
and other.dtype != self.dtype
54585458
):
5459+
# TODO(infer_string) can we avoid this special case?
54595460
# special case for object behavior
54605461
return other.equals(self.astype(object))
54615462

‎pandas/core/internals/construction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def ndarray_to_mgr(
302302
nb = new_block_2d(values, placement=bp, refs=refs)
303303
block_values = [nb]
304304
elif dtype is None and values.dtype.kind == "U" and using_string_dtype():
305-
dtype = StringDtype(storage="pyarrow_numpy")
305+
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
306306

307307
obj_columns = list(values)
308308
block_values = [

‎pandas/core/reshape/encoding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212

13+
from pandas._libs import missing as libmissing
1314
from pandas._libs.sparse import IntIndex
1415

1516
from pandas.core.dtypes.common import (
@@ -256,7 +257,7 @@ def _get_dummies_1d(
256257
dtype = ArrowDtype(pa.bool_()) # type: ignore[assignment]
257258
elif (
258259
isinstance(input_dtype, StringDtype)
259-
and input_dtype.storage != "pyarrow_numpy"
260+
and input_dtype.na_value is libmissing.NA
260261
):
261262
dtype = pandas_dtype("boolean") # type: ignore[assignment]
262263
else:

‎pandas/core/reshape/merge.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2677,8 +2677,7 @@ def _factorize_keys(
26772677

26782678
elif isinstance(lk, ExtensionArray) and lk.dtype == rk.dtype:
26792679
if (isinstance(lk.dtype, ArrowDtype) and is_string_dtype(lk.dtype)) or (
2680-
isinstance(lk.dtype, StringDtype)
2681-
and lk.dtype.storage in ["pyarrow", "pyarrow_numpy"]
2680+
isinstance(lk.dtype, StringDtype) and lk.dtype.storage == "pyarrow"
26822681
):
26832682
import pyarrow as pa
26842683
import pyarrow.compute as pc

‎pandas/core/tools/numeric.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77

88
import numpy as np
99

10-
from pandas._libs import lib
10+
from pandas._libs import (
11+
lib,
12+
missing as libmissing,
13+
)
1114
from pandas.util._validators import check_dtype_backend
1215

1316
from pandas.core.dtypes.cast import maybe_downcast_numeric
@@ -218,7 +221,7 @@ def to_numeric(
218221
coerce_numeric=coerce_numeric,
219222
convert_to_masked_nullable=dtype_backend is not lib.no_default
220223
or isinstance(values_dtype, StringDtype)
221-
and not values_dtype.storage == "pyarrow_numpy",
224+
and values_dtype.na_value is libmissing.NA,
222225
)
223226

224227
if new_mask is not None:
@@ -229,7 +232,7 @@ def to_numeric(
229232
dtype_backend is not lib.no_default
230233
and new_mask is None
231234
or isinstance(values_dtype, StringDtype)
232-
and not values_dtype.storage == "pyarrow_numpy"
235+
and values_dtype.na_value is libmissing.NA
233236
):
234237
new_mask = np.zeros(values.shape, dtype=np.bool_)
235238

‎pandas/io/_util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import TYPE_CHECKING
44

5+
import numpy as np
6+
57
from pandas.compat._optional import import_optional_dependency
68

79
import pandas as pd
@@ -32,6 +34,6 @@ def arrow_string_types_mapper() -> Callable:
3234
pa = import_optional_dependency("pyarrow")
3335

3436
return {
35-
pa.string(): pd.StringDtype(storage="pyarrow_numpy"),
36-
pa.large_string(): pd.StringDtype(storage="pyarrow_numpy"),
37+
pa.string(): pd.StringDtype(storage="pyarrow", na_value=np.nan),
38+
pa.large_string(): pd.StringDtype(storage="pyarrow", na_value=np.nan),
3739
}.get

‎pandas/tests/arrays/string_/test_string.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@
2020
)
2121

2222

23-
def na_val(dtype):
24-
if dtype.storage == "pyarrow_numpy":
25-
return np.nan
26-
else:
27-
return pd.NA
28-
29-
3023
@pytest.fixture
3124
def dtype(string_storage):
3225
"""Fixture giving StringDtype from parametrized 'string_storage'"""
@@ -39,24 +32,45 @@ def cls(dtype):
3932
return dtype.construct_array_type()
4033

4134

35+
def test_dtype_equality():
36+
pytest.importorskip("pyarrow")
37+
38+
dtype1 = pd.StringDtype("python")
39+
dtype2 = pd.StringDtype("pyarrow")
40+
dtype3 = pd.StringDtype("pyarrow", na_value=np.nan)
41+
42+
assert dtype1 == pd.StringDtype("python", na_value=pd.NA)
43+
assert dtype1 != dtype2
44+
assert dtype1 != dtype3
45+
46+
assert dtype2 == pd.StringDtype("pyarrow", na_value=pd.NA)
47+
assert dtype2 != dtype1
48+
assert dtype2 != dtype3
49+
50+
assert dtype3 == pd.StringDtype("pyarrow", na_value=np.nan)
51+
assert dtype3 == pd.StringDtype("pyarrow", na_value=float("nan"))
52+
assert dtype3 != dtype1
53+
assert dtype3 != dtype2
54+
55+
4256
def test_repr(dtype):
4357
df = pd.DataFrame({"A": pd.array(["a", pd.NA, "b"], dtype=dtype)})
44-
if dtype.storage == "pyarrow_numpy":
58+
if dtype.na_value is np.nan:
4559
expected = " A\n0 a\n1 NaN\n2 b"
4660
else:
4761
expected = " A\n0 a\n1 <NA>\n2 b"
4862
assert repr(df) == expected
4963

50-
if dtype.storage == "pyarrow_numpy":
64+
if dtype.na_value is np.nan:
5165
expected = "0 a\n1 NaN\n2 b\nName: A, dtype: string"
5266
else:
5367
expected = "0 a\n1 <NA>\n2 b\nName: A, dtype: string"
5468
assert repr(df.A) == expected
5569

56-
if dtype.storage == "pyarrow":
70+
if dtype.storage == "pyarrow" and dtype.na_value is pd.NA:
5771
arr_name = "ArrowStringArray"
5872
expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
59-
elif dtype.storage == "pyarrow_numpy":
73+
elif dtype.storage == "pyarrow" and dtype.na_value is np.nan:
6074
arr_name = "ArrowStringArrayNumpySemantics"
6175
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: string"
6276
else:
@@ -68,7 +82,7 @@ def test_repr(dtype):
6882
def test_none_to_nan(cls, dtype):
6983
a = cls._from_sequence(["a", None, "b"], dtype=dtype)
7084
assert a[1] is not None
71-
assert a[1] is na_val(a.dtype)
85+
assert a[1] is a.dtype.na_value
7286

7387

7488
def test_setitem_validates(cls, dtype):
@@ -225,7 +239,7 @@ def test_comparison_methods_scalar(comparison_op, dtype):
225239
a = pd.array(["a", None, "c"], dtype=dtype)
226240
other = "a"
227241
result = getattr(a, op_name)(other)
228-
if dtype.storage == "pyarrow_numpy":
242+
if dtype.na_value is np.nan:
229243
expected = np.array([getattr(item, op_name)(other) for item in a])
230244
if comparison_op == operator.ne:
231245
expected[1] = True
@@ -244,7 +258,7 @@ def test_comparison_methods_scalar_pd_na(comparison_op, dtype):
244258
a = pd.array(["a", None, "c"], dtype=dtype)
245259
result = getattr(a, op_name)(pd.NA)
246260

247-
if dtype.storage == "pyarrow_numpy":
261+
if dtype.na_value is np.nan:
248262
if operator.ne == comparison_op:
249263
expected = np.array([True, True, True])
250264
else:
@@ -271,7 +285,7 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype):
271285

272286
result = getattr(a, op_name)(other)
273287

274-
if dtype.storage == "pyarrow_numpy":
288+
if dtype.na_value is np.nan:
275289
expected_data = {
276290
"__eq__": [False, False, False],
277291
"__ne__": [True, True, True],
@@ -293,7 +307,7 @@ def test_comparison_methods_array(comparison_op, dtype):
293307
a = pd.array(["a", None, "c"], dtype=dtype)
294308
other = [None, None, "c"]
295309
result = getattr(a, op_name)(other)
296-
if dtype.storage == "pyarrow_numpy":
310+
if dtype.na_value is np.nan:
297311
if operator.ne == comparison_op:
298312
expected = np.array([True, True, False])
299313
else:
@@ -387,7 +401,7 @@ def test_astype_int(dtype):
387401
tm.assert_numpy_array_equal(result, expected)
388402

389403
arr = pd.array(["1", pd.NA, "3"], dtype=dtype)
390-
if dtype.storage == "pyarrow_numpy":
404+
if dtype.na_value is np.nan:
391405
err = ValueError
392406
msg = "cannot convert float NaN to integer"
393407
else:
@@ -441,7 +455,7 @@ def test_min_max(method, skipna, dtype):
441455
expected = "a" if method == "min" else "c"
442456
assert result == expected
443457
else:
444-
assert result is na_val(arr.dtype)
458+
assert result is arr.dtype.na_value
445459

446460

447461
@pytest.mark.parametrize("method", ["min", "max"])
@@ -490,7 +504,7 @@ def test_arrow_array(dtype):
490504
data = pd.array(["a", "b", "c"], dtype=dtype)
491505
arr = pa.array(data)
492506
expected = pa.array(list(data), type=pa.large_string(), from_pandas=True)
493-
if dtype.storage in ("pyarrow", "pyarrow_numpy") and pa_version_under12p0:
507+
if dtype.storage == "pyarrow" and pa_version_under12p0:
494508
expected = pa.chunked_array(expected)
495509
if dtype.storage == "python":
496510
expected = pc.cast(expected, pa.string())
@@ -522,7 +536,7 @@ def test_arrow_roundtrip(dtype, string_storage2, request, using_infer_string):
522536
expected = df.astype(f"string[{string_storage2}]")
523537
tm.assert_frame_equal(result, expected)
524538
# ensure the missing value is represented by NA and not np.nan or None
525-
assert result.loc[2, "a"] is na_val(result["a"].dtype)
539+
assert result.loc[2, "a"] is result["a"].dtype.na_value
526540

527541

528542
@pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning")
@@ -556,10 +570,10 @@ def test_arrow_load_from_zero_chunks(
556570

557571

558572
def test_value_counts_na(dtype):
559-
if getattr(dtype, "storage", "") == "pyarrow":
560-
exp_dtype = "int64[pyarrow]"
561-
elif getattr(dtype, "storage", "") == "pyarrow_numpy":
573+
if dtype.na_value is np.nan:
562574
exp_dtype = "int64"
575+
elif dtype.storage == "pyarrow":
576+
exp_dtype = "int64[pyarrow]"
563577
else:
564578
exp_dtype = "Int64"
565579
arr = pd.array(["a", "b", "a", pd.NA], dtype=dtype)
@@ -573,10 +587,10 @@ def test_value_counts_na(dtype):
573587

574588

575589
def test_value_counts_with_normalize(dtype):
576-
if getattr(dtype, "storage", "") == "pyarrow":
577-
exp_dtype = "double[pyarrow]"
578-
elif getattr(dtype, "storage", "") == "pyarrow_numpy":
590+
if dtype.na_value is np.nan:
579591
exp_dtype = np.float64
592+
elif dtype.storage == "pyarrow":
593+
exp_dtype = "double[pyarrow]"
580594
else:
581595
exp_dtype = "Float64"
582596
ser = pd.Series(["a", "b", "a", pd.NA], dtype=dtype)
@@ -586,10 +600,10 @@ def test_value_counts_with_normalize(dtype):
586600

587601

588602
def test_value_counts_sort_false(dtype):
589-
if getattr(dtype, "storage", "") == "pyarrow":
590-
exp_dtype = "int64[pyarrow]"
591-
elif getattr(dtype, "storage", "") == "pyarrow_numpy":
603+
if dtype.na_value is np.nan:
592604
exp_dtype = "int64"
605+
elif dtype.storage == "pyarrow":
606+
exp_dtype = "int64[pyarrow]"
593607
else:
594608
exp_dtype = "Int64"
595609
ser = pd.Series(["a", "b", "c", "b"], dtype=dtype)
@@ -621,7 +635,7 @@ def test_astype_from_float_dtype(float_dtype, dtype):
621635
def test_to_numpy_returns_pdna_default(dtype):
622636
arr = pd.array(["a", pd.NA, "b"], dtype=dtype)
623637
result = np.array(arr)
624-
expected = np.array(["a", na_val(dtype), "b"], dtype=object)
638+
expected = np.array(["a", dtype.na_value, "b"], dtype=object)
625639
tm.assert_numpy_array_equal(result, expected)
626640

627641

@@ -661,7 +675,7 @@ def test_setitem_scalar_with_mask_validation(dtype):
661675
mask = np.array([False, True, False])
662676

663677
ser[mask] = None
664-
assert ser.array[1] is na_val(ser.dtype)
678+
assert ser.array[1] is ser.dtype.na_value
665679

666680
# for other non-string we should also raise an error
667681
ser = pd.Series(["a", "b", "c"], dtype=dtype)

‎pandas/tests/arrays/string_/test_string_arrow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def test_eq_all_na():
2929
def test_config(string_storage, request, using_infer_string):
3030
if using_infer_string and string_storage != "pyarrow_numpy":
3131
request.applymarker(pytest.mark.xfail(reason="infer string takes precedence"))
32+
if string_storage == "pyarrow_numpy":
33+
request.applymarker(pytest.mark.xfail(reason="TODO(infer_string)"))
3234
with pd.option_context("string_storage", string_storage):
3335
assert StringDtype().storage == string_storage
3436
result = pd.array(["a", "b"])
@@ -260,6 +262,6 @@ def test_pickle_roundtrip(dtype):
260262
def test_string_dtype_error_message():
261263
# GH#55051
262264
pytest.importorskip("pyarrow")
263-
msg = "Storage must be 'python', 'pyarrow' or 'pyarrow_numpy'."
265+
msg = "Storage must be 'python' or 'pyarrow'."
264266
with pytest.raises(ValueError, match=msg):
265267
StringDtype("bla")

‎pandas/tests/extension/base/methods.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,14 @@ def test_value_counts_with_normalize(self, data):
6666
expected = pd.Series(0.0, index=result.index, name="proportion")
6767
expected[result > 0] = 1 / len(values)
6868

69-
if getattr(data.dtype, "storage", "") == "pyarrow" or isinstance(
69+
if isinstance(data.dtype, pd.StringDtype) and data.dtype.na_value is np.nan:
70+
# TODO: avoid special-casing
71+
expected = expected.astype("float64")
72+
elif getattr(data.dtype, "storage", "") == "pyarrow" or isinstance(
7073
data.dtype, pd.ArrowDtype
7174
):
7275
# TODO: avoid special-casing
7376
expected = expected.astype("double[pyarrow]")
74-
elif getattr(data.dtype, "storage", "") == "pyarrow_numpy":
75-
# TODO: avoid special-casing
76-
expected = expected.astype("float64")
7777
elif na_value_for_dtype(data.dtype) is pd.NA:
7878
# TODO(GH#44692): avoid special-casing
7979
expected = expected.astype("Float64")

‎pandas/tests/extension/test_string.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,15 @@ def data_for_grouping(dtype, chunked):
9696

9797
class TestStringArray(base.ExtensionTests):
9898
def test_eq_with_str(self, dtype):
99-
assert dtype == f"string[{dtype.storage}]"
10099
super().test_eq_with_str(dtype)
101100

101+
if dtype.na_value is pd.NA:
102+
# only the NA-variant supports parametrized string alias
103+
assert dtype == f"string[{dtype.storage}]"
104+
elif dtype.storage == "pyarrow":
105+
# TODO(infer_string) deprecate this
106+
assert dtype == "string[pyarrow_numpy]"
107+
102108
def test_is_not_string_type(self, dtype):
103109
# Different from BaseDtypeTests.test_is_not_string_type
104110
# because StringDtype is a string type
@@ -140,28 +146,21 @@ def _get_expected_exception(
140146
self, op_name: str, obj, other
141147
) -> type[Exception] | None:
142148
if op_name in ["__divmod__", "__rdivmod__"]:
143-
if isinstance(obj, pd.Series) and cast(
144-
StringDtype, tm.get_dtype(obj)
145-
).storage in [
146-
"pyarrow",
147-
"pyarrow_numpy",
148-
]:
149+
if (
150+
isinstance(obj, pd.Series)
151+
and cast(StringDtype, tm.get_dtype(obj)).storage == "pyarrow"
152+
):
149153
# TODO: re-raise as TypeError?
150154
return NotImplementedError
151-
elif isinstance(other, pd.Series) and cast(
152-
StringDtype, tm.get_dtype(other)
153-
).storage in [
154-
"pyarrow",
155-
"pyarrow_numpy",
156-
]:
155+
elif (
156+
isinstance(other, pd.Series)
157+
and cast(StringDtype, tm.get_dtype(other)).storage == "pyarrow"
158+
):
157159
# TODO: re-raise as TypeError?
158160
return NotImplementedError
159161
return TypeError
160162
elif op_name in ["__mod__", "__rmod__", "__pow__", "__rpow__"]:
161-
if cast(StringDtype, tm.get_dtype(obj)).storage in [
162-
"pyarrow",
163-
"pyarrow_numpy",
164-
]:
163+
if cast(StringDtype, tm.get_dtype(obj)).storage == "pyarrow":
165164
return NotImplementedError
166165
return TypeError
167166
elif op_name in ["__mul__", "__rmul__"]:
@@ -175,10 +174,7 @@ def _get_expected_exception(
175174
"__sub__",
176175
"__rsub__",
177176
]:
178-
if cast(StringDtype, tm.get_dtype(obj)).storage in [
179-
"pyarrow",
180-
"pyarrow_numpy",
181-
]:
177+
if cast(StringDtype, tm.get_dtype(obj)).storage == "pyarrow":
182178
import pyarrow as pa
183179

184180
# TODO: better to re-raise as TypeError?
@@ -190,18 +186,18 @@ def _get_expected_exception(
190186
def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
191187
return (
192188
op_name in ["min", "max"]
193-
or ser.dtype.storage == "pyarrow_numpy" # type: ignore[union-attr]
189+
or ser.dtype.na_value is np.nan # type: ignore[union-attr]
194190
and op_name in ("any", "all")
195191
)
196192

197193
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
198194
dtype = cast(StringDtype, tm.get_dtype(obj))
199195
if op_name in ["__add__", "__radd__"]:
200196
cast_to = dtype
197+
elif dtype.na_value is np.nan:
198+
cast_to = np.bool_ # type: ignore[assignment]
201199
elif dtype.storage == "pyarrow":
202200
cast_to = "boolean[pyarrow]" # type: ignore[assignment]
203-
elif dtype.storage == "pyarrow_numpy":
204-
cast_to = np.bool_ # type: ignore[assignment]
205201
else:
206202
cast_to = "boolean" # type: ignore[assignment]
207203
return pointwise_result.astype(cast_to)

‎pandas/tests/frame/methods/test_convert_dtypes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test_convert_dtypes(
1818
# Just check that it works for DataFrame here
1919
if using_infer_string:
2020
string_storage = "pyarrow_numpy"
21+
2122
df = pd.DataFrame(
2223
{
2324
"a": pd.Series([1, 2, 3], dtype=np.dtype("int32")),

‎pandas/tests/series/test_constructors.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2113,9 +2113,12 @@ def test_series_string_inference_array_string_dtype(self):
21132113
tm.assert_series_equal(ser, expected)
21142114

21152115
def test_series_string_inference_storage_definition(self):
2116-
# GH#54793
2116+
# https://github.com/pandas-dev/pandas/issues/54793
2117+
# but after PDEP-14 (string dtype), it was decided to keep dtype="string"
2118+
# returning the NA string dtype, so expected is changed from
2119+
# "string[pyarrow_numpy]" to "string[pyarrow]"
21172120
pytest.importorskip("pyarrow")
2118-
expected = Series(["a", "b"], dtype="string[pyarrow_numpy]")
2121+
expected = Series(["a", "b"], dtype="string[pyarrow]")
21192122
with pd.option_context("future.infer_string", True):
21202123
result = Series(["a", "b"], dtype="string")
21212124
tm.assert_series_equal(result, expected)

‎pandas/tests/strings/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
def _convert_na_value(ser, expected):
99
if ser.dtype != object:
10-
if ser.dtype.storage == "pyarrow_numpy":
10+
if ser.dtype.na_value is np.nan:
1111
expected = expected.fillna(np.nan)
1212
else:
1313
# GH#18463

0 commit comments

Comments
 (0)
Please sign in to comment.