Skip to content

Commit 206d2f0

Browse files
committed
add NumpyStringArray and string[numpy] dtype
1 parent 56ae252 commit 206d2f0

File tree

9 files changed

+133
-56
lines changed

9 files changed

+133
-56
lines changed

asv_bench/benchmarks/strings.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,19 @@
1515

1616

1717
class Dtypes:
18-
params = ["str", "string[python]", "string[pyarrow]", StringDType()]
18+
params = [
19+
"str",
20+
"string[python]",
21+
"string[pyarrow]",
22+
"string[numpy]",
23+
StringDType(),
24+
]
1925
param_names = ["dtype"]
2026
dtype_mapping = {
2127
"str": "str",
2228
"string[python]": object,
2329
"string[pyarrow]": object,
30+
"string[numpy]": StringDType(),
2431
StringDType(): StringDType(),
2532
}
2633

@@ -37,14 +44,15 @@ def setup(self, dtype):
3744
class Construction:
3845
params = (
3946
["series", "frame", "categorical_series"],
40-
["str", "string[python]", "string[pyarrow]", StringDType()],
47+
["str", "string[python]", "string[pyarrow]", "string[numpy]", StringDType()],
4148
)
4249
param_names = ["pd_type", "dtype"]
4350
pd_mapping = {"series": Series, "frame": DataFrame, "categorical_series": Series}
4451
dtype_mapping = {
4552
"str": "str",
4653
"string[python]": object,
4754
"string[pyarrow]": object,
55+
"string[numpy]": StringDType(),
4856
StringDType(): StringDType(),
4957
}
5058

pandas/_libs/lib.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -1860,7 +1860,7 @@ cdef class StringValidator(Validator):
18601860
return isinstance(value, str)
18611861

18621862
cdef bint is_array_typed(self) except -1:
1863-
return issubclass(self.dtype.type, np.str_)
1863+
return issubclass(self.dtype.type, (np.str_, str))
18641864

18651865

18661866
cpdef bint is_string_array(ndarray values, bint skipna=False):

pandas/core/arrays/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
period_array,
1818
)
1919
from pandas.core.arrays.sparse import SparseArray
20-
from pandas.core.arrays.string_ import StringArray
20+
from pandas.core.arrays.string_ import (
21+
NumpyStringArray,
22+
ObjectStringArray,
23+
StringArray,
24+
)
2125
from pandas.core.arrays.string_arrow import ArrowStringArray
2226
from pandas.core.arrays.timedeltas import TimedeltaArray
2327

@@ -39,5 +43,7 @@
3943
"period_array",
4044
"SparseArray",
4145
"StringArray",
46+
"ObjectStringArray",
47+
"NumpyStringArray",
4248
"TimedeltaArray",
4349
]

pandas/core/arrays/string_.py

+95-44
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
missing as libmissing,
1515
)
1616
from pandas._libs.arrays import NDArrayBacked
17-
from pandas.compat import pa_version_under7p0
17+
from pandas.compat import (
18+
is_numpy_dev,
19+
pa_version_under7p0,
20+
)
1821
from pandas.compat.numpy import function as nv
1922
from pandas.util._decorators import doc
2023

@@ -24,6 +27,7 @@
2427
register_extension_dtype,
2528
)
2629
from pandas.core.dtypes.common import (
30+
get_string_dtype,
2731
is_array_like,
2832
is_bool_dtype,
2933
is_integer_dtype,
@@ -76,7 +80,7 @@ class StringDtype(StorageExtensionDtype):
7680
7781
Parameters
7882
----------
79-
storage : {"python", "pyarrow"}, optional
83+
storage : {"python", "pyarrow", "numpy"}, optional
8084
If not given, the value of ``pd.options.mode.string_storage``.
8185
8286
Attributes
@@ -108,14 +112,17 @@ def na_value(self) -> libmissing.NAType:
108112
def __init__(self, storage=None) -> None:
109113
if storage is None:
110114
storage = get_option("mode.string_storage")
111-
if storage not in {"python", "pyarrow"}:
115+
if storage not in {"python", "pyarrow", "numpy"}:
112116
raise ValueError(
113-
f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
117+
"Storage must be 'python', 'pyarrow', or 'numpy'. "
118+
"Got {storage} instead."
114119
)
115120
if storage == "pyarrow" and pa_version_under7p0:
116121
raise ImportError(
117122
"pyarrow>=7.0.0 is required for PyArrow backed StringArray."
118123
)
124+
if storage == "numpy" and not is_numpy_dev:
125+
raise ImportError("NumPy backed string storage requires numpy dev")
119126
self.storage = storage
120127

121128
@property
@@ -139,6 +146,7 @@ def construct_from_string(cls, string):
139146
``'string'`` pd.options.mode.string_storage, default python
140147
``'string[python]'`` python
141148
``'string[pyarrow]'`` pyarrow
149+
``'string[numpy]'`` numpy
142150
========================== ==============================================
143151
144152
Returns
@@ -160,6 +168,8 @@ def construct_from_string(cls, string):
160168
return cls(storage="python")
161169
elif string == "string[pyarrow]":
162170
return cls(storage="pyarrow")
171+
elif string == "string[numpy]":
172+
return cls(storage="numpy")
163173
else:
164174
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
165175

@@ -179,9 +189,13 @@ def construct_array_type( # type: ignore[override]
179189
from pandas.core.arrays.string_arrow import ArrowStringArray
180190

181191
if self.storage == "python":
182-
return StringArray
183-
else:
192+
return ObjectStringArray
193+
elif self.storage == "pyarrow":
184194
return ArrowStringArray
195+
elif self.storage == "numpy":
196+
return NumpyStringArray
197+
else:
198+
raise NotImplementedError
185199

186200
def __from_arrow__(
187201
self, array: pyarrow.Array | pyarrow.ChunkedArray
@@ -231,7 +245,7 @@ def tolist(self):
231245

232246
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
233247
# incompatible with definition in base class "ExtensionArray"
234-
class StringArray(BaseStringArray, PandasArray): # type: ignore[misc]
248+
class BaseNumpyStringArray(BaseStringArray, PandasArray): # type: ignore[misc]
235249
"""
236250
Extension array for string data.
237251
@@ -321,54 +335,23 @@ def __init__(self, values, copy: bool = False) -> None:
321335
super().__init__(values, copy=copy)
322336
if not isinstance(values, type(self)):
323337
self._validate()
324-
NDArrayBacked.__init__(self, self._ndarray, StringDtype(storage="python"))
338+
NDArrayBacked.__init__(self, self._ndarray, StringDtype(storage=self._storage))
325339

326340
def _validate(self):
327341
"""Validate that we only store NA or strings."""
328342
if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True):
329343
raise ValueError("StringArray requires a sequence of strings or pandas.NA")
330-
if self._ndarray.dtype != "object":
344+
if self._ndarray.dtype != self._cache_dtype:
331345
raise ValueError(
332-
"StringArray requires a sequence of strings or pandas.NA. Got "
346+
f"{type(self).__name__} requires a sequence of strings or "
347+
"pandas.NA convertible to a NumPy array with dtype "
348+
f"{self._cache_dtype}. Got "
333349
f"'{self._ndarray.dtype}' dtype instead."
334350
)
335-
# Check to see if need to convert Na values to pd.NA
336-
if self._ndarray.ndim > 2:
337-
# Ravel if ndims > 2 b/c no cythonized version available
338-
lib.convert_nans_to_NA(self._ndarray.ravel("K"))
339-
else:
340-
lib.convert_nans_to_NA(self._ndarray)
341351

342352
@classmethod
343353
def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False):
344-
if dtype and not (isinstance(dtype, str) and dtype == "string"):
345-
dtype = pandas_dtype(dtype)
346-
assert isinstance(dtype, StringDtype) and dtype.storage == "python"
347-
348-
from pandas.core.arrays.masked import BaseMaskedArray
349-
350-
if isinstance(scalars, BaseMaskedArray):
351-
# avoid costly conversion to object dtype
352-
na_values = scalars._mask
353-
result = scalars._data
354-
result = lib.ensure_string_array(result, copy=copy, convert_na_value=False)
355-
result[na_values] = libmissing.NA
356-
357-
else:
358-
if hasattr(scalars, "type"):
359-
# pyarrow array; we cannot rely on the "to_numpy" check in
360-
# ensure_string_array because calling scalars.to_numpy would set
361-
# zero_copy_only to True which caused problems see GH#52076
362-
scalars = np.array(scalars)
363-
# convert non-na-likes to str, and nan-likes to StringDtype().na_value
364-
result = lib.ensure_string_array(scalars, na_value=libmissing.NA, copy=copy)
365-
366-
# Manually creating new array avoids the validation step in the __init__, so is
367-
# faster. Refactor need for validation?
368-
new_string_array = cls.__new__(cls)
369-
NDArrayBacked.__init__(new_string_array, result, StringDtype(storage="python"))
370-
371-
return new_string_array
354+
raise NotImplementedError("_from_sequence must be implemented in subclasses")
372355

373356
@classmethod
374357
def _from_sequence_of_strings(
@@ -612,3 +595,71 @@ def _str_map(
612595
# or .findall returns a list).
613596
# -> We don't know the result type. E.g. `.get` can return anything.
614597
return lib.map_infer_mask(arr, f, mask.view("uint8"))
598+
599+
600+
class ObjectStringArray(BaseNumpyStringArray):
601+
_cache_dtype = "object"
602+
_storage = "python"
603+
604+
def _validate(self):
605+
super()._validate()
606+
# Check to see if need to convert Na values to pd.NA
607+
if self._ndarray.ndim > 2:
608+
# Ravel if ndims > 2 b/c no cythonized version available
609+
lib.convert_nans_to_NA(self._ndarray.ravel("K"))
610+
else:
611+
lib.convert_nans_to_NA(self._ndarray)
612+
613+
@classmethod
614+
def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False):
615+
if dtype and not (isinstance(dtype, str) and dtype == "string"):
616+
dtype = pandas_dtype(dtype)
617+
assert isinstance(dtype, StringDtype) and dtype.storage == "python"
618+
619+
from pandas.core.arrays.masked import BaseMaskedArray
620+
621+
if isinstance(scalars, BaseMaskedArray):
622+
# avoid costly conversion to object dtype
623+
na_values = scalars._mask
624+
result = scalars._data
625+
result = lib.ensure_string_array(result, copy=copy, convert_na_value=False)
626+
result[na_values] = libmissing.NA
627+
628+
else:
629+
if hasattr(scalars, "type"):
630+
# pyarrow array; we cannot rely on the "to_numpy" check in
631+
# ensure_string_array because calling scalars.to_numpy would set
632+
# zero_copy_only to True which caused problems see GH#52076
633+
scalars = np.array(scalars)
634+
# convert non-na-likes to str, and nan-likes to StringDtype().na_value
635+
result = lib.ensure_string_array(scalars, na_value=libmissing.NA, copy=copy)
636+
637+
# Manually creating new array avoids the validation step in the __init__, so is
638+
# faster. Refactor need for validation?
639+
new_string_array = cls.__new__(cls)
640+
NDArrayBacked.__init__(
641+
new_string_array, result, StringDtype(storage=cls._storage)
642+
)
643+
644+
return new_string_array
645+
646+
647+
StringArray = ObjectStringArray
648+
649+
650+
class NumpyStringArray(BaseNumpyStringArray):
651+
_cache_dtype = get_string_dtype()
652+
_storage = "numpy"
653+
654+
@classmethod
655+
def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False):
656+
result = np.array(scalars, dtype=cls._cache_dtype)
657+
658+
# Manually creating new array avoids the validation step in the __init__, so is
659+
# faster. Refactor need for validation?
660+
new_string_array = cls.__new__(cls)
661+
NDArrayBacked.__init__(
662+
new_string_array, result, StringDtype(storage=cls._storage)
663+
)
664+
665+
return new_string_array

pandas/core/construction.py

+1
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,7 @@ def sanitize_array(
536536
-------
537537
np.ndarray or ExtensionArray
538538
"""
539+
539540
if isinstance(data, ma.MaskedArray):
540541
data = sanitize_masked_array(data)
541542

pandas/core/dtypes/common.py

+13
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Period,
1818
algos,
1919
lib,
20+
missing,
2021
)
2122
from pandas._libs.tslibs import conversion
2223
from pandas.util._exceptions import find_stack_level
@@ -518,6 +519,18 @@ def is_string_or_object_np_dtype(dtype: np.dtype) -> bool:
518519
return dtype == object or dtype.kind in "SU" or issubclass(dtype.type, str)
519520

520521

522+
def get_string_dtype():
523+
import os
524+
import sys
525+
526+
if not os.environ.get("NUMPY_EXPERIMENTAL_DTYPE_API", None) == "1":
527+
sys.exit()
528+
529+
import stringdtype
530+
531+
return stringdtype.StringDType(na_object=missing.NA)
532+
533+
521534
def is_string_dtype(arr_or_dtype) -> bool:
522535
"""
523536
Check whether the provided array or dtype is of the string dtype.

pandas/core/dtypes/missing.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
)
1212

1313
import numpy as np
14-
from stringdtype import StringDType
1514

1615
from pandas._config import get_option
1716

@@ -26,6 +25,7 @@
2625
DT64NS_DTYPE,
2726
TD64NS_DTYPE,
2827
ensure_object,
28+
get_string_dtype,
2929
is_scalar,
3030
is_string_or_object_np_dtype,
3131
)
@@ -300,6 +300,9 @@ def _isna_array(values: ArrayLike, inf_as_na: bool = False):
300300
return result
301301

302302

303+
StringDType = type(get_string_dtype())
304+
305+
303306
def _isna_string_dtype(values: np.ndarray, inf_as_na: bool) -> npt.NDArray[np.bool_]:
304307
# Working around NumPy ticket 1542
305308
dtype = values.dtype

pandas/core/indexes/base.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import warnings
2222

2323
import numpy as np
24-
from stringdtype import StringDType
2524

2625
from pandas._config import get_option
2726

@@ -507,7 +506,7 @@ def __new__(
507506
if isinstance(data, ABCMultiIndex):
508507
data = data._values
509508

510-
if data.dtype.kind not in "iufcbmM" and type(data.dtype) != StringDType:
509+
if data.dtype.kind not in "iufcbmM":
511510
# GH#11836 we need to avoid having numpy coerce
512511
# things that look like ints/floats to ints unless
513512
# they are actually ints, e.g. '0' and 0.0

pandas/core/strings/object_array.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import unicodedata
1515

1616
import numpy as np
17-
from stringdtype import StringDType
1817

1918
from pandas._libs import lib
2019
import pandas._libs.missing as libmissing
@@ -82,10 +81,7 @@ def _str_map(
8281

8382
arr = np.asarray(self)
8483
mask = isna(arr)
85-
type(arr.dtype)
86-
map_convert = (
87-
convert and not np.all(mask) and type(arr.dtype) is not StringDType
88-
)
84+
map_convert = convert and not np.all(mask)
8985
try:
9086
result = lib.map_infer_mask(arr, f, mask.view(np.uint8), map_convert)
9187
except (TypeError, AttributeError) as err:

0 commit comments

Comments
 (0)