8
8
9
9
import numpy as np
10
10
11
- from pandas ._config import get_option
11
+ from pandas ._config import (
12
+ get_option ,
13
+ using_string_dtype ,
14
+ )
12
15
13
16
from pandas ._libs import (
14
17
lib ,
@@ -80,8 +83,10 @@ class StringDtype(StorageExtensionDtype):
80
83
81
84
Parameters
82
85
----------
83
- storage : {"python", "pyarrow", "pyarrow_numpy" }, optional
86
+ storage : {"python", "pyarrow"}, optional
84
87
If not given, the value of ``pd.options.mode.string_storage``.
88
+ na_value : {np.nan, pd.NA}, default pd.NA
89
+ Whether the dtype follows NaN or NA missing value semantics.
85
90
86
91
Attributes
87
92
----------
@@ -108,30 +113,67 @@ class StringDtype(StorageExtensionDtype):
108
113
# follows NumPy semantics, which uses nan.
109
114
@property
def na_value(self) -> libmissing.NAType | float:  # type: ignore[override]
    """
    Missing-value sentinel used by this dtype instance.

    Returns
    -------
    libmissing.NAType or float
        The value stored on the instance as ``_na_value`` by ``__init__``.
    """
    sentinel = self._na_value
    return sentinel
115
117
116
- _metadata = ("storage" ,)
118
+ _metadata = ("storage" , "_na_value" ) # type: ignore[assignment]
117
119
118
def __init__(
    self,
    storage: str | None = None,
    na_value: libmissing.NAType | float = libmissing.NA,
) -> None:
    """
    Resolve, validate and store the storage backend and missing-value sentinel.

    Parameters
    ----------
    storage : {"python", "pyarrow"}, optional
        If not given, "pyarrow" is used when ``using_string_dtype()`` is
        true; otherwise the value of the ``mode.string_storage`` option.
        The legacy value "pyarrow_numpy" is still accepted and is mapped to
        ``storage="pyarrow"`` with ``na_value=np.nan``.
    na_value : {np.nan, pd.NA}, default pd.NA
        Whether the dtype follows NaN or NA missing value semantics.

    Raises
    ------
    ValueError
        If ``storage`` is not "python" or "pyarrow" after resolution, or if
        ``na_value`` is neither a NaN float nor ``pd.NA``.
    ImportError
        If ``storage`` is "pyarrow" and ``pa_version_under10p1`` is true.
    """
    # infer defaults
    if storage is None:
        if using_string_dtype():
            storage = "pyarrow"
        else:
            storage = get_option("mode.string_storage")

    if storage == "pyarrow_numpy":
        # TODO raise a deprecation warning
        # Legacy alias: the NaN-semantics variant is now expressed through
        # na_value, so any na_value passed by the caller is overridden here.
        storage = "pyarrow"
        na_value = np.nan

    # validate options
    if storage not in {"python", "pyarrow"}:
        raise ValueError(
            f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
        )
    if storage == "pyarrow" and pa_version_under10p1:
        raise ImportError(
            "pyarrow>=10.0.1 is required for PyArrow backed StringArray."
        )

    if isinstance(na_value, float) and np.isnan(na_value):
        # when passed a NaN value, always set to np.nan to ensure we use
        # a consistent NaN value (and we can use `dtype.na_value is np.nan`)
        na_value = np.nan
    elif na_value is not libmissing.NA:
        # The message must be an f-string so the rejected value is shown
        # instead of the literal text "{na_value}".
        raise ValueError(f"'na_value' must be np.nan or pd.NA, got {na_value}")

    self.storage = storage
    self._na_value = na_value
156
+
157
def __eq__(self, other: object) -> bool:
    """
    Compare this dtype with another dtype or with a dtype string.

    The base class ``__eq__`` is overridden because ``na_value`` (NA or NaN)
    cannot be checked with normal ``==``; it is compared by identity instead.

    Parameters
    ----------
    other : object
        A dtype instance, or a string that is first matched against
        ``self.name`` and otherwise parsed with ``construct_from_string``.

    Returns
    -------
    bool
        True when ``other`` (after optional string parsing) is an instance of
        this dtype's type with equal ``storage`` and the identical
        ``na_value`` object; False otherwise.
    """
    candidate = other
    if isinstance(candidate, str):
        # A string equal to this dtype's own name matches immediately.
        if candidate == self.name:
            return True
        try:
            candidate = self.construct_from_string(candidate)
        except TypeError:
            # Unparseable string: cannot be equal to this dtype.
            return False

    if not isinstance(candidate, type(self)):
        return False

    storage_matches = self.storage == candidate.storage
    return storage_matches and self.na_value is candidate.na_value
170
+
171
def __hash__(self) -> int:
    """
    Return the hash computed by the parent class.

    Must be defined explicitly because this class overrides ``__eq__``;
    without it the class would lose the inherited ``__hash__``.
    """
    inherited_hash = super().__hash__()
    return inherited_hash
174
+
175
def __reduce__(self):
    """
    Pickle support: rebuild the dtype by calling ``StringDtype`` again.

    Returns
    -------
    tuple
        ``(StringDtype, (storage, na_value))`` - the callable and the
        positional arguments used to reconstruct this instance.
    """
    ctor_args = (self.storage, self.na_value)
    return StringDtype, ctor_args
135
177
136
178
@property
137
179
def type (self ) -> type [str ]:
@@ -176,6 +218,7 @@ def construct_from_string(cls, string) -> Self:
176
218
elif string == "string[pyarrow]" :
177
219
return cls (storage = "pyarrow" )
178
220
elif string == "string[pyarrow_numpy]" :
221
+ # TODO deprecate
179
222
return cls (storage = "pyarrow_numpy" )
180
223
else :
181
224
raise TypeError (f"Cannot construct a '{ cls .__name__ } ' from '{ string } '" )
@@ -200,7 +243,7 @@ def construct_array_type( # type: ignore[override]
200
243
201
244
if self .storage == "python" :
202
245
return StringArray
203
- elif self .storage == "pyarrow" :
246
+ elif self .storage == "pyarrow" and self . _na_value is libmissing . NA :
204
247
return ArrowStringArray
205
248
else :
206
249
return ArrowStringArrayNumpySemantics
@@ -212,13 +255,17 @@ def __from_arrow__(
212
255
Construct StringArray from pyarrow Array/ChunkedArray.
213
256
"""
214
257
if self .storage == "pyarrow" :
215
- from pandas .core .arrays .string_arrow import ArrowStringArray
258
+ if self ._na_value is libmissing .NA :
259
+ from pandas .core .arrays .string_arrow import ArrowStringArray
260
+
261
+ return ArrowStringArray (array )
262
+ else :
263
+ from pandas .core .arrays .string_arrow import (
264
+ ArrowStringArrayNumpySemantics ,
265
+ )
216
266
217
- return ArrowStringArray (array )
218
- elif self .storage == "pyarrow_numpy" :
219
- from pandas .core .arrays .string_arrow import ArrowStringArrayNumpySemantics
267
+ return ArrowStringArrayNumpySemantics (array )
220
268
221
- return ArrowStringArrayNumpySemantics (array )
222
269
else :
223
270
import pyarrow
224
271
0 commit comments