Skip to content

Commit 1272cb1

Browse files
String dtype: implement object-dtype based StringArray variant with NumPy semantics (#58451)
Co-authored-by: Patrick Hoefler <[email protected]>
1 parent aa134bb commit 1272cb1

File tree

14 files changed

+231
-48
lines changed

14 files changed

+231
-48
lines changed

pandas/_libs/lib.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -2702,7 +2702,7 @@ def maybe_convert_objects(ndarray[object] objects,
27022702
if using_string_dtype() and is_string_array(objects, skipna=True):
27032703
from pandas.core.arrays.string_ import StringDtype
27042704

2705-
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
2705+
dtype = StringDtype(na_value=np.nan)
27062706
return dtype.construct_array_type()._from_sequence(objects, dtype=dtype)
27072707

27082708
elif convert_to_nullable_dtype and is_string_array(objects, skipna=True):

pandas/_testing/asserters.py

+18
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,24 @@ def assert_extension_array_equal(
796796
left_na, right_na, obj=f"{obj} NA mask", index_values=index_values
797797
)
798798

799+
# Specifically for StringArrayNumpySemantics, validate here we have a valid array
800+
if (
801+
isinstance(left.dtype, StringDtype)
802+
and left.dtype.storage == "python"
803+
and left.dtype.na_value is np.nan
804+
):
805+
assert np.all(
806+
[np.isnan(val) for val in left._ndarray[left_na]] # type: ignore[attr-defined]
807+
), "wrong missing value sentinels"
808+
if (
809+
isinstance(right.dtype, StringDtype)
810+
and right.dtype.storage == "python"
811+
and right.dtype.na_value is np.nan
812+
):
813+
assert np.all(
814+
[np.isnan(val) for val in right._ndarray[right_na]] # type: ignore[attr-defined]
815+
), "wrong missing value sentinels"
816+
799817
left_valid = left[~left_na].to_numpy(dtype=object)
800818
right_valid = right[~right_na].to_numpy(dtype=object)
801819
if check_exact:

pandas/compat/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from pandas.compat.numpy import is_numpy_dev
2727
from pandas.compat.pyarrow import (
28+
HAS_PYARROW,
2829
pa_version_under10p1,
2930
pa_version_under11p0,
3031
pa_version_under13p0,
@@ -156,6 +157,7 @@ def is_ci_environment() -> bool:
156157
"pa_version_under14p1",
157158
"pa_version_under16p0",
158159
"pa_version_under17p0",
160+
"HAS_PYARROW",
159161
"IS64",
160162
"ISMUSL",
161163
"PY311",

pandas/compat/pyarrow.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
pa_version_under15p0 = _palv < Version("15.0.0")
1818
pa_version_under16p0 = _palv < Version("16.0.0")
1919
pa_version_under17p0 = _palv < Version("17.0.0")
20+
HAS_PYARROW = True
2021
except ImportError:
2122
pa_version_under10p1 = True
2223
pa_version_under11p0 = True
@@ -27,3 +28,4 @@
2728
pa_version_under15p0 = True
2829
pa_version_under16p0 = True
2930
pa_version_under17p0 = True
31+
HAS_PYARROW = False

pandas/conftest.py

+4
Original file line numberDiff line numberDiff line change
@@ -1313,6 +1313,7 @@ def string_storage(request):
13131313
("python", pd.NA),
13141314
pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
13151315
pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
1316+
("python", np.nan),
13161317
]
13171318
)
13181319
def string_dtype_arguments(request):
@@ -1374,12 +1375,14 @@ def object_dtype(request):
13741375
("python", pd.NA),
13751376
pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
13761377
pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
1378+
("python", np.nan),
13771379
],
13781380
ids=[
13791381
"string=object",
13801382
"string=string[python]",
13811383
"string=string[pyarrow]",
13821384
"string=str[pyarrow]",
1385+
"string=str[python]",
13831386
],
13841387
)
13851388
def any_string_dtype(request):
@@ -1389,6 +1392,7 @@ def any_string_dtype(request):
13891392
* 'string[python]' (NA variant)
13901393
* 'string[pyarrow]' (NA variant)
13911394
* 'str' (NaN variant, with pyarrow)
1395+
* 'str' (NaN variant, without pyarrow)
13921396
"""
13931397
if isinstance(request.param, np.dtype):
13941398
return request.param

0 commit comments

Comments
 (0)