Skip to content

Commit 253169a

Browse files
jorisvandenbosscheWillAyd
authored andcommitted
String dtype: rename the storage options and add na_value keyword in StringDtype() (#59330)
* rename storage option and add na_value keyword * update init * fix propagating na_value to Array class + fix some tests * fix more tests * disallow pyarrow_numpy as option + fix more cases of checking storage to be pyarrow_numpy * restore pyarrow_numpy as option for now * linting * try fix typing * try fix typing * fix dtype equality to take into account the NaN vs NA * fix pickling of dtype * fix test_convert_dtypes * update expected result for dtype='string' * suppress typing error with _metadata attribute
1 parent 332624b commit 253169a

File tree

20 files changed

+194
-122
lines changed

20 files changed

+194
-122
lines changed

pandas/_libs/lib.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -2728,7 +2728,7 @@ def maybe_convert_objects(ndarray[object] objects,
27282728
if using_string_dtype() and is_string_array(objects, skipna=True):
27292729
from pandas.core.arrays.string_ import StringDtype
27302730

2731-
dtype = StringDtype(storage="pyarrow_numpy")
2731+
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
27322732
return dtype.construct_array_type()._from_sequence(objects, dtype=dtype)
27332733

27342734
seen.object_ = True

pandas/_testing/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -526,14 +526,14 @@ def shares_memory(left, right) -> bool:
526526
if (
527527
isinstance(left, ExtensionArray)
528528
and is_string_dtype(left.dtype)
529-
and left.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined]
529+
and left.dtype.storage == "pyarrow" # type: ignore[attr-defined]
530530
):
531531
# https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669
532532
left = cast("ArrowExtensionArray", left)
533533
if (
534534
isinstance(right, ExtensionArray)
535535
and is_string_dtype(right.dtype)
536-
and right.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined]
536+
and right.dtype.storage == "pyarrow" # type: ignore[attr-defined]
537537
):
538538
right = cast("ArrowExtensionArray", right)
539539
left_pa_data = left._pa_array

pandas/core/arrays/arrow/array.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -570,10 +570,8 @@ def __getitem__(self, item: PositionalIndexer):
570570
if isinstance(item, np.ndarray):
571571
if not len(item):
572572
# Removable once we migrate StringDtype[pyarrow] to ArrowDtype[string]
573-
if self._dtype.name == "string" and self._dtype.storage in (
574-
"pyarrow",
575-
"pyarrow_numpy",
576-
):
573+
if self._dtype.name == "string" and self._dtype.storage == "pyarrow":
574+
# TODO(infer_string) should this be large_string?
577575
pa_dtype = pa.string()
578576
else:
579577
pa_dtype = self._dtype.pyarrow_dtype

pandas/core/arrays/string_.py

+68-21
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
import numpy as np
1010

11-
from pandas._config import get_option
11+
from pandas._config import (
12+
get_option,
13+
using_string_dtype,
14+
)
1215

1316
from pandas._libs import (
1417
lib,
@@ -80,8 +83,10 @@ class StringDtype(StorageExtensionDtype):
8083
8184
Parameters
8285
----------
83-
storage : {"python", "pyarrow", "pyarrow_numpy"}, optional
86+
storage : {"python", "pyarrow"}, optional
8487
If not given, the value of ``pd.options.mode.string_storage``.
88+
na_value : {np.nan, pd.NA}, default pd.NA
89+
Whether the dtype follows NaN or NA missing value semantics.
8590
8691
Attributes
8792
----------
@@ -108,30 +113,67 @@ class StringDtype(StorageExtensionDtype):
108113
# follows NumPy semantics, which uses nan.
109114
@property
110115
def na_value(self) -> libmissing.NAType | float: # type: ignore[override]
111-
if self.storage == "pyarrow_numpy":
112-
return np.nan
113-
else:
114-
return libmissing.NA
116+
return self._na_value
115117

116-
_metadata = ("storage",)
118+
_metadata = ("storage", "_na_value") # type: ignore[assignment]
117119

118-
def __init__(self, storage=None) -> None:
120+
def __init__(
121+
self,
122+
storage: str | None = None,
123+
na_value: libmissing.NAType | float = libmissing.NA,
124+
) -> None:
125+
# infer defaults
119126
if storage is None:
120-
infer_string = get_option("future.infer_string")
121-
if infer_string:
122-
storage = "pyarrow_numpy"
127+
if using_string_dtype():
128+
storage = "pyarrow"
123129
else:
124130
storage = get_option("mode.string_storage")
125-
if storage not in {"python", "pyarrow", "pyarrow_numpy"}:
131+
132+
if storage == "pyarrow_numpy":
133+
# TODO raise a deprecation warning
134+
storage = "pyarrow"
135+
na_value = np.nan
136+
137+
# validate options
138+
if storage not in {"python", "pyarrow"}:
126139
raise ValueError(
127-
f"Storage must be 'python', 'pyarrow' or 'pyarrow_numpy'. "
128-
f"Got {storage} instead."
140+
f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
129141
)
130-
if storage in ("pyarrow", "pyarrow_numpy") and pa_version_under10p1:
142+
if storage == "pyarrow" and pa_version_under10p1:
131143
raise ImportError(
132144
"pyarrow>=10.0.1 is required for PyArrow backed StringArray."
133145
)
146+
147+
if isinstance(na_value, float) and np.isnan(na_value):
148+
# when passed a NaN value, always set to np.nan to ensure we use
149+
# a consistent NaN value (and we can use `dtype.na_value is np.nan`)
150+
na_value = np.nan
151+
elif na_value is not libmissing.NA:
152+
raise ValueError("'na_value' must be np.nan or pd.NA, got {na_value}")
153+
134154
self.storage = storage
155+
self._na_value = na_value
156+
157+
def __eq__(self, other: object) -> bool:
158+
# we need to override the base class __eq__ because na_value (NA or NaN)
159+
# cannot be checked with normal `==`
160+
if isinstance(other, str):
161+
if other == self.name:
162+
return True
163+
try:
164+
other = self.construct_from_string(other)
165+
except TypeError:
166+
return False
167+
if isinstance(other, type(self)):
168+
return self.storage == other.storage and self.na_value is other.na_value
169+
return False
170+
171+
def __hash__(self) -> int:
172+
# need to override __hash__ as well because of overriding __eq__
173+
return super().__hash__()
174+
175+
def __reduce__(self):
176+
return StringDtype, (self.storage, self.na_value)
135177

136178
@property
137179
def type(self) -> type[str]:
@@ -176,6 +218,7 @@ def construct_from_string(cls, string) -> Self:
176218
elif string == "string[pyarrow]":
177219
return cls(storage="pyarrow")
178220
elif string == "string[pyarrow_numpy]":
221+
# TODO deprecate
179222
return cls(storage="pyarrow_numpy")
180223
else:
181224
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
@@ -200,7 +243,7 @@ def construct_array_type( # type: ignore[override]
200243

201244
if self.storage == "python":
202245
return StringArray
203-
elif self.storage == "pyarrow":
246+
elif self.storage == "pyarrow" and self._na_value is libmissing.NA:
204247
return ArrowStringArray
205248
else:
206249
return ArrowStringArrayNumpySemantics
@@ -212,13 +255,17 @@ def __from_arrow__(
212255
Construct StringArray from pyarrow Array/ChunkedArray.
213256
"""
214257
if self.storage == "pyarrow":
215-
from pandas.core.arrays.string_arrow import ArrowStringArray
258+
if self._na_value is libmissing.NA:
259+
from pandas.core.arrays.string_arrow import ArrowStringArray
260+
261+
return ArrowStringArray(array)
262+
else:
263+
from pandas.core.arrays.string_arrow import (
264+
ArrowStringArrayNumpySemantics,
265+
)
216266

217-
return ArrowStringArray(array)
218-
elif self.storage == "pyarrow_numpy":
219-
from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics
267+
return ArrowStringArrayNumpySemantics(array)
220268

221-
return ArrowStringArrayNumpySemantics(array)
222269
else:
223270
import pyarrow
224271

pandas/core/arrays/string_arrow.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr
125125
# base class "ArrowExtensionArray" defined the type as "ArrowDtype")
126126
_dtype: StringDtype # type: ignore[assignment]
127127
_storage = "pyarrow"
128+
_na_value: libmissing.NAType | float = libmissing.NA
128129

129130
def __init__(self, values) -> None:
130131
_chk_pyarrow_available()
@@ -134,7 +135,7 @@ def __init__(self, values) -> None:
134135
values = pc.cast(values, pa.large_string())
135136

136137
super().__init__(values)
137-
self._dtype = StringDtype(storage=self._storage)
138+
self._dtype = StringDtype(storage=self._storage, na_value=self._na_value)
138139

139140
if not pa.types.is_large_string(self._pa_array.type) and not (
140141
pa.types.is_dictionary(self._pa_array.type)
@@ -179,10 +180,7 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = Fal
179180

180181
if dtype and not (isinstance(dtype, str) and dtype == "string"):
181182
dtype = pandas_dtype(dtype)
182-
assert isinstance(dtype, StringDtype) and dtype.storage in (
183-
"pyarrow",
184-
"pyarrow_numpy",
185-
)
183+
assert isinstance(dtype, StringDtype) and dtype.storage == "pyarrow"
186184

187185
if isinstance(scalars, BaseMaskedArray):
188186
# avoid costly conversion to object dtype in ensure_string_array and
@@ -596,7 +594,8 @@ def _rank(
596594

597595

598596
class ArrowStringArrayNumpySemantics(ArrowStringArray):
599-
_storage = "pyarrow_numpy"
597+
_storage = "pyarrow"
598+
_na_value = np.nan
600599

601600
@classmethod
602601
def _result_converter(cls, values, na=None):

pandas/core/construction.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def sanitize_array(
569569
if isinstance(data, str) and using_string_dtype() and original_dtype is None:
570570
from pandas.core.arrays.string_ import StringDtype
571571

572-
dtype = StringDtype("pyarrow_numpy")
572+
dtype = StringDtype("pyarrow", na_value=np.nan)
573573
data = construct_1d_arraylike_from_scalar(data, len(index), dtype)
574574

575575
return data
@@ -606,7 +606,7 @@ def sanitize_array(
606606
elif data.dtype.kind == "U" and using_string_dtype():
607607
from pandas.core.arrays.string_ import StringDtype
608608

609-
dtype = StringDtype(storage="pyarrow_numpy")
609+
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
610610
subarr = dtype.construct_array_type()._from_sequence(data, dtype=dtype)
611611

612612
if subarr is data and copy:

pandas/core/dtypes/cast.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,7 @@ def infer_dtype_from_scalar(val) -> tuple[DtypeObj, Any]:
802802
if using_string_dtype():
803803
from pandas.core.arrays.string_ import StringDtype
804804

805-
dtype = StringDtype(storage="pyarrow_numpy")
805+
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
806806

807807
elif isinstance(val, (np.datetime64, dt.datetime)):
808808
try:

pandas/core/indexes/base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5620,9 +5620,10 @@ def equals(self, other: Any) -> bool:
56205620

56215621
if (
56225622
isinstance(self.dtype, StringDtype)
5623-
and self.dtype.storage == "pyarrow_numpy"
5623+
and self.dtype.na_value is np.nan
56245624
and other.dtype != self.dtype
56255625
):
5626+
# TODO(infer_string) can we avoid this special case?
56265627
# special case for object behavior
56275628
return other.equals(self.astype(object))
56285629

pandas/core/internals/construction.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def ndarray_to_mgr(
376376
nb = new_block_2d(values, placement=bp, refs=refs)
377377
block_values = [nb]
378378
elif dtype is None and values.dtype.kind == "U" and using_string_dtype():
379-
dtype = StringDtype(storage="pyarrow_numpy")
379+
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
380380

381381
obj_columns = list(values)
382382
block_values = [

pandas/core/reshape/encoding.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import numpy as np
1515

16+
from pandas._libs import missing as libmissing
1617
from pandas._libs.sparse import IntIndex
1718

1819
from pandas.core.dtypes.common import (
@@ -260,7 +261,7 @@ def _get_dummies_1d(
260261
dtype = ArrowDtype(pa.bool_()) # type: ignore[assignment]
261262
elif (
262263
isinstance(input_dtype, StringDtype)
263-
and input_dtype.storage != "pyarrow_numpy"
264+
and input_dtype.na_value is libmissing.NA
264265
):
265266
dtype = pandas_dtype("boolean") # type: ignore[assignment]
266267
else:

pandas/core/reshape/merge.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -2473,8 +2473,7 @@ def _factorize_keys(
24732473

24742474
elif isinstance(lk, ExtensionArray) and lk.dtype == rk.dtype:
24752475
if (isinstance(lk.dtype, ArrowDtype) and is_string_dtype(lk.dtype)) or (
2476-
isinstance(lk.dtype, StringDtype)
2477-
and lk.dtype.storage in ["pyarrow", "pyarrow_numpy"]
2476+
isinstance(lk.dtype, StringDtype) and lk.dtype.storage == "pyarrow"
24782477
):
24792478
import pyarrow as pa
24802479
import pyarrow.compute as pc

pandas/core/tools/numeric.py

+14-18
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
import numpy as np
1010

11-
from pandas._libs import lib
11+
from pandas._libs import (
12+
lib,
13+
missing as libmissing,
14+
)
1215
from pandas.util._exceptions import find_stack_level
1316
from pandas.util._validators import check_dtype_backend
1417

@@ -207,8 +210,6 @@ def to_numeric(
207210
else:
208211
values = arg
209212

210-
orig_values = values
211-
212213
# GH33013: for IntegerArray & FloatingArray extract non-null values for casting
213214
# save mask to reconstruct the full array after casting
214215
mask: npt.NDArray[np.bool_] | None = None
@@ -227,20 +228,15 @@ def to_numeric(
227228
values = values.view(np.int64)
228229
else:
229230
values = ensure_object(values)
230-
coerce_numeric = errors not in ("ignore", "raise")
231-
try:
232-
values, new_mask = lib.maybe_convert_numeric( # type: ignore[call-overload]
233-
values,
234-
set(),
235-
coerce_numeric=coerce_numeric,
236-
convert_to_masked_nullable=dtype_backend is not lib.no_default
237-
or isinstance(values_dtype, StringDtype)
238-
and not values_dtype.storage == "pyarrow_numpy",
239-
)
240-
except (ValueError, TypeError):
241-
if errors == "raise":
242-
raise
243-
values = orig_values
231+
coerce_numeric = errors != "raise"
232+
values, new_mask = lib.maybe_convert_numeric( # type: ignore[call-overload]
233+
values,
234+
set(),
235+
coerce_numeric=coerce_numeric,
236+
convert_to_masked_nullable=dtype_backend is not lib.no_default
237+
or isinstance(values_dtype, StringDtype)
238+
and values_dtype.na_value is libmissing.NA,
239+
)
244240

245241
if new_mask is not None:
246242
# Remove unnecessary values, is expected later anyway and enables
@@ -250,7 +246,7 @@ def to_numeric(
250246
dtype_backend is not lib.no_default
251247
and new_mask is None
252248
or isinstance(values_dtype, StringDtype)
253-
and not values_dtype.storage == "pyarrow_numpy"
249+
and values_dtype.na_value is libmissing.NA
254250
):
255251
new_mask = np.zeros(values.shape, dtype=np.bool_)
256252

pandas/io/_util.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import Callable
44

5+
import numpy as np
6+
57
from pandas.compat._optional import import_optional_dependency
68

79
import pandas as pd
@@ -29,6 +31,6 @@ def arrow_string_types_mapper() -> Callable:
2931
pa = import_optional_dependency("pyarrow")
3032

3133
return {
32-
pa.string(): pd.StringDtype(storage="pyarrow_numpy"),
33-
pa.large_string(): pd.StringDtype(storage="pyarrow_numpy"),
34+
pa.string(): pd.StringDtype(storage="pyarrow", na_value=np.nan),
35+
pa.large_string(): pd.StringDtype(storage="pyarrow", na_value=np.nan),
3436
}.get

0 commit comments

Comments
 (0)