diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index 2650d60eb3cef..0bb47541e5963 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -2702,7 +2702,7 @@ def maybe_convert_objects(ndarray[object] objects, if using_string_dtype() and is_string_array(objects, skipna=True): from pandas.core.arrays.string_ import StringDtype - dtype = StringDtype(storage="pyarrow_numpy") + dtype = StringDtype(storage="pyarrow", na_value=np.nan) return dtype.construct_array_type()._from_sequence(objects, dtype=dtype) elif convert_to_nullable_dtype and is_string_array(objects, skipna=True): diff --git a/pandas/_testing/__init__.py b/pandas/_testing/__init__.py index 1cd91ee5b120c..3aa53d4b07aa5 100644 --- a/pandas/_testing/__init__.py +++ b/pandas/_testing/__init__.py @@ -509,14 +509,14 @@ def shares_memory(left, right) -> bool: if ( isinstance(left, ExtensionArray) and is_string_dtype(left.dtype) - and left.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined] + and left.dtype.storage == "pyarrow" # type: ignore[attr-defined] ): # https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669 left = cast("ArrowExtensionArray", left) if ( isinstance(right, ExtensionArray) and is_string_dtype(right.dtype) - and right.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined] + and right.dtype.storage == "pyarrow" # type: ignore[attr-defined] ): right = cast("ArrowExtensionArray", right) left_pa_data = left._pa_array diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 5da479760047f..a17056b51a014 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -575,10 +575,8 @@ def __getitem__(self, item: PositionalIndexer): if isinstance(item, np.ndarray): if not len(item): # Removable once we migrate StringDtype[pyarrow] to ArrowDtype[string] - if self._dtype.name == "string" and self._dtype.storage in ( - "pyarrow", - "pyarrow_numpy", - ): + if self._dtype.name == "string" and self._dtype.storage == "pyarrow": + # TODO(infer_string) should this be large_string? pa_dtype = pa.string() else: pa_dtype = self._dtype.pyarrow_dtype diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 13c26f0c97934..cae770d85637c 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -9,7 +9,10 @@ import numpy as np -from pandas._config import get_option +from pandas._config import ( + get_option, + using_string_dtype, +) from pandas._libs import ( lib, @@ -81,8 +84,10 @@ class StringDtype(StorageExtensionDtype): Parameters ---------- - storage : {"python", "pyarrow", "pyarrow_numpy"}, optional + storage : {"python", "pyarrow"}, optional If not given, the value of ``pd.options.mode.string_storage``. + na_value : {np.nan, pd.NA}, default pd.NA + Whether the dtype follows NaN or NA missing value semantics. Attributes ---------- @@ -113,30 +118,67 @@ class StringDtype(StorageExtensionDtype): # follows NumPy semantics, which uses nan. @property def na_value(self) -> libmissing.NAType | float: # type: ignore[override] - if self.storage == "pyarrow_numpy": - return np.nan - else: - return libmissing.NA + return self._na_value - _metadata = ("storage",) + _metadata = ("storage", "_na_value") # type: ignore[assignment] - def __init__(self, storage=None) -> None: + def __init__( + self, + storage: str | None = None, + na_value: libmissing.NAType | float = libmissing.NA, + ) -> None: + # infer defaults if storage is None: - infer_string = get_option("future.infer_string") - if infer_string: - storage = "pyarrow_numpy" + if using_string_dtype(): + storage = "pyarrow" else: storage = get_option("mode.string_storage") - if storage not in {"python", "pyarrow", "pyarrow_numpy"}: + + if storage == "pyarrow_numpy": + # TODO raise a deprecation warning + storage = "pyarrow" + na_value = np.nan + + # validate options + if storage not in {"python", "pyarrow"}: raise ValueError( - f"Storage must be 'python', 'pyarrow' or 'pyarrow_numpy'. " - f"Got {storage} instead." + f"Storage must be 'python' or 'pyarrow'. Got {storage} instead." ) - if storage in ("pyarrow", "pyarrow_numpy") and pa_version_under10p1: + if storage == "pyarrow" and pa_version_under10p1: raise ImportError( "pyarrow>=10.0.1 is required for PyArrow backed StringArray." ) + + if isinstance(na_value, float) and np.isnan(na_value): + # when passed a NaN value, always set to np.nan to ensure we use + # a consistent NaN value (and we can use `dtype.na_value is np.nan`) + na_value = np.nan + elif na_value is not libmissing.NA: + raise ValueError("'na_value' must be np.nan or pd.NA, got {na_value}") + self.storage = storage + self._na_value = na_value + + def __eq__(self, other: object) -> bool: + # we need to override the base class __eq__ because na_value (NA or NaN) + # cannot be checked with normal `==` + if isinstance(other, str): + if other == self.name: + return True + try: + other = self.construct_from_string(other) + except TypeError: + return False + if isinstance(other, type(self)): + return self.storage == other.storage and self.na_value is other.na_value + return False + + def __hash__(self) -> int: + # need to override __hash__ as well because of overriding __eq__ + return super().__hash__() + + def __reduce__(self): + return StringDtype, (self.storage, self.na_value) @property def type(self) -> type[str]: @@ -181,6 +223,7 @@ def construct_from_string(cls, string) -> Self: elif string == "string[pyarrow]": return cls(storage="pyarrow") elif string == "string[pyarrow_numpy]": + # TODO deprecate return cls(storage="pyarrow_numpy") else: raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'") @@ -205,7 +248,7 @@ def construct_array_type( # type: ignore[override] if self.storage == "python": return StringArray - elif self.storage == "pyarrow": + elif self.storage == "pyarrow" and self._na_value is libmissing.NA: return ArrowStringArray else: return ArrowStringArrayNumpySemantics @@ -217,13 +260,17 @@ def __from_arrow__( Construct StringArray from pyarrow Array/ChunkedArray. """ if self.storage == "pyarrow": - from pandas.core.arrays.string_arrow import ArrowStringArray + if self._na_value is libmissing.NA: + from pandas.core.arrays.string_arrow import ArrowStringArray + + return ArrowStringArray(array) + else: + from pandas.core.arrays.string_arrow import ( + ArrowStringArrayNumpySemantics, + ) - return ArrowStringArray(array) - elif self.storage == "pyarrow_numpy": - from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics + return ArrowStringArrayNumpySemantics(array) - return ArrowStringArrayNumpySemantics(array) else: import pyarrow diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 97c06149d0b7e..869cc34d5f61d 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -131,6 +131,7 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr # base class "ArrowExtensionArray" defined the type as "ArrowDtype") _dtype: StringDtype # type: ignore[assignment] _storage = "pyarrow" + _na_value: libmissing.NAType | float = libmissing.NA def __init__(self, values) -> None: _chk_pyarrow_available() @@ -140,7 +141,7 @@ def __init__(self, values) -> None: values = pc.cast(values, pa.large_string()) super().__init__(values) - self._dtype = StringDtype(storage=self._storage) + self._dtype = StringDtype(storage=self._storage, na_value=self._na_value) if not pa.types.is_large_string(self._pa_array.type) and not ( pa.types.is_dictionary(self._pa_array.type) @@ -187,10 +188,7 @@ def _from_sequence( if dtype and not (isinstance(dtype, str) and dtype == "string"): dtype = pandas_dtype(dtype) - assert isinstance(dtype, StringDtype) and dtype.storage in ( - "pyarrow", - "pyarrow_numpy", - ) + assert isinstance(dtype, StringDtype) and dtype.storage == "pyarrow" if isinstance(scalars, BaseMaskedArray): # avoid costly conversion to object dtype in ensure_string_array and @@ -597,7 +595,8 @@ def _rank( class ArrowStringArrayNumpySemantics(ArrowStringArray): - _storage = "pyarrow_numpy" + _storage = "pyarrow" + _na_value = np.nan @classmethod def _result_converter(cls, values, na=None): diff --git a/pandas/core/construction.py b/pandas/core/construction.py index 32792aa7f0543..81aeb40f375b0 100644 --- a/pandas/core/construction.py +++ b/pandas/core/construction.py @@ -574,7 +574,7 @@ def sanitize_array( if isinstance(data, str) and using_string_dtype() and original_dtype is None: from pandas.core.arrays.string_ import StringDtype - dtype = StringDtype("pyarrow_numpy") + dtype = StringDtype("pyarrow", na_value=np.nan) data = construct_1d_arraylike_from_scalar(data, len(index), dtype) return data @@ -608,7 +608,7 @@ def sanitize_array( elif data.dtype.kind == "U" and using_string_dtype(): from pandas.core.arrays.string_ import StringDtype - dtype = StringDtype(storage="pyarrow_numpy") + dtype = StringDtype(storage="pyarrow", na_value=np.nan) subarr = dtype.construct_array_type()._from_sequence(data, dtype=dtype) if subarr is data and copy: diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 21e45505b40fc..d750451a1ca84 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -801,7 +801,7 @@ def infer_dtype_from_scalar(val) -> tuple[DtypeObj, Any]: if using_string_dtype(): from pandas.core.arrays.string_ import StringDtype - dtype = StringDtype(storage="pyarrow_numpy") + dtype = StringDtype(storage="pyarrow", na_value=np.nan) elif isinstance(val, (np.datetime64, dt.datetime)): try: diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index e67c59c86dd0c..50f44cc728aea 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -5453,9 +5453,10 @@ def equals(self, other: Any) -> bool: if ( isinstance(self.dtype, StringDtype) - and self.dtype.storage == "pyarrow_numpy" + and self.dtype.na_value is np.nan and other.dtype != self.dtype ): + # TODO(infer_string) can we avoid this special case? # special case for object behavior return other.equals(self.astype(object)) diff --git a/pandas/core/internals/construction.py b/pandas/core/internals/construction.py index c31479b3011e5..08e1650a5de12 100644 --- a/pandas/core/internals/construction.py +++ b/pandas/core/internals/construction.py @@ -302,7 +302,7 @@ def ndarray_to_mgr( nb = new_block_2d(values, placement=bp, refs=refs) block_values = [nb] elif dtype is None and values.dtype.kind == "U" and using_string_dtype(): - dtype = StringDtype(storage="pyarrow_numpy") + dtype = StringDtype(storage="pyarrow", na_value=np.nan) obj_columns = list(values) block_values = [ diff --git a/pandas/core/reshape/encoding.py b/pandas/core/reshape/encoding.py index 9d88e61951e99..c397c1c2566a5 100644 --- a/pandas/core/reshape/encoding.py +++ b/pandas/core/reshape/encoding.py @@ -10,6 +10,7 @@ import numpy as np +from pandas._libs import missing as libmissing from pandas._libs.sparse import IntIndex from pandas.core.dtypes.common import ( @@ -256,7 +257,7 @@ def _get_dummies_1d( dtype = ArrowDtype(pa.bool_()) # type: ignore[assignment] elif ( isinstance(input_dtype, StringDtype) - and input_dtype.storage != "pyarrow_numpy" + and input_dtype.na_value is libmissing.NA ): dtype = pandas_dtype("boolean") # type: ignore[assignment] else: diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index 2ce77ac19b9c5..6364072fd215c 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -2677,8 +2677,7 @@ def _factorize_keys( elif isinstance(lk, ExtensionArray) and lk.dtype == rk.dtype: if (isinstance(lk.dtype, ArrowDtype) and is_string_dtype(lk.dtype)) or ( - isinstance(lk.dtype, StringDtype) - and lk.dtype.storage in ["pyarrow", "pyarrow_numpy"] + isinstance(lk.dtype, StringDtype) and lk.dtype.storage == "pyarrow" ): import pyarrow as pa import pyarrow.compute as pc diff --git a/pandas/core/tools/numeric.py b/pandas/core/tools/numeric.py index 3d406d3bfb115..26e73794af298 100644 --- a/pandas/core/tools/numeric.py +++ b/pandas/core/tools/numeric.py @@ -7,7 +7,10 @@ import numpy as np -from pandas._libs import lib +from pandas._libs import ( + lib, + missing as libmissing, +) from pandas.util._validators import check_dtype_backend from pandas.core.dtypes.cast import maybe_downcast_numeric @@ -218,7 +221,7 @@ def to_numeric( coerce_numeric=coerce_numeric, convert_to_masked_nullable=dtype_backend is not lib.no_default or isinstance(values_dtype, StringDtype) - and not values_dtype.storage == "pyarrow_numpy", + and values_dtype.na_value is libmissing.NA, ) if new_mask is not None: @@ -229,7 +232,7 @@ def to_numeric( dtype_backend is not lib.no_default and new_mask is None or isinstance(values_dtype, StringDtype) - and not values_dtype.storage == "pyarrow_numpy" + and values_dtype.na_value is libmissing.NA ): new_mask = np.zeros(values.shape, dtype=np.bool_) diff --git a/pandas/io/_util.py b/pandas/io/_util.py index cb0f89945e440..a72a16269959d 100644 --- a/pandas/io/_util.py +++ b/pandas/io/_util.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING +import numpy as np + from pandas.compat._optional import import_optional_dependency import pandas as pd @@ -32,6 +34,6 @@ def arrow_string_types_mapper() -> Callable: pa = import_optional_dependency("pyarrow") return { - pa.string(): pd.StringDtype(storage="pyarrow_numpy"), - pa.large_string(): pd.StringDtype(storage="pyarrow_numpy"), + pa.string(): pd.StringDtype(storage="pyarrow", na_value=np.nan), + pa.large_string(): pd.StringDtype(storage="pyarrow", na_value=np.nan), }.get diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index 597b407a29c94..7757847f3c841 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -20,13 +20,6 @@ ) -def na_val(dtype): - if dtype.storage == "pyarrow_numpy": - return np.nan - else: - return pd.NA - - @pytest.fixture def dtype(string_storage): """Fixture giving StringDtype from parametrized 'string_storage'""" @@ -39,24 +32,45 @@ def cls(dtype): return dtype.construct_array_type() +def test_dtype_equality(): + pytest.importorskip("pyarrow") + + dtype1 = pd.StringDtype("python") + dtype2 = pd.StringDtype("pyarrow") + dtype3 = pd.StringDtype("pyarrow", na_value=np.nan) + + assert dtype1 == pd.StringDtype("python", na_value=pd.NA) + assert dtype1 != dtype2 + assert dtype1 != dtype3 + + assert dtype2 == pd.StringDtype("pyarrow", na_value=pd.NA) + assert dtype2 != dtype1 + assert dtype2 != dtype3 + + assert dtype3 == pd.StringDtype("pyarrow", na_value=np.nan) + assert dtype3 == pd.StringDtype("pyarrow", na_value=float("nan")) + assert dtype3 != dtype1 + assert dtype3 != dtype2 + + def test_repr(dtype): df = pd.DataFrame({"A": pd.array(["a", pd.NA, "b"], dtype=dtype)}) - if dtype.storage == "pyarrow_numpy": + if dtype.na_value is np.nan: expected = " A\n0 a\n1 NaN\n2 b" else: expected = " A\n0 a\n1 \n2 b" assert repr(df) == expected - if dtype.storage == "pyarrow_numpy": + if dtype.na_value is np.nan: expected = "0 a\n1 NaN\n2 b\nName: A, dtype: string" else: expected = "0 a\n1 \n2 b\nName: A, dtype: string" assert repr(df.A) == expected - if dtype.storage == "pyarrow": + if dtype.storage == "pyarrow" and dtype.na_value is pd.NA: arr_name = "ArrowStringArray" expected = f"<{arr_name}>\n['a', , 'b']\nLength: 3, dtype: string" - elif dtype.storage == "pyarrow_numpy": + elif dtype.storage == "pyarrow" and dtype.na_value is np.nan: arr_name = "ArrowStringArrayNumpySemantics" expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: string" else: @@ -68,7 +82,7 @@ def test_repr(dtype): def test_none_to_nan(cls, dtype): a = cls._from_sequence(["a", None, "b"], dtype=dtype) assert a[1] is not None - assert a[1] is na_val(a.dtype) + assert a[1] is a.dtype.na_value def test_setitem_validates(cls, dtype): @@ -225,7 +239,7 @@ def test_comparison_methods_scalar(comparison_op, dtype): a = pd.array(["a", None, "c"], dtype=dtype) other = "a" result = getattr(a, op_name)(other) - if dtype.storage == "pyarrow_numpy": + if dtype.na_value is np.nan: expected = np.array([getattr(item, op_name)(other) for item in a]) if comparison_op == operator.ne: expected[1] = True @@ -244,7 +258,7 @@ def test_comparison_methods_scalar_pd_na(comparison_op, dtype): a = pd.array(["a", None, "c"], dtype=dtype) result = getattr(a, op_name)(pd.NA) - if dtype.storage == "pyarrow_numpy": + if dtype.na_value is np.nan: if operator.ne == comparison_op: expected = np.array([True, True, True]) else: @@ -271,7 +285,7 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype): result = getattr(a, op_name)(other) - if dtype.storage == "pyarrow_numpy": + if dtype.na_value is np.nan: expected_data = { "__eq__": [False, False, False], "__ne__": [True, True, True], @@ -293,7 +307,7 @@ def test_comparison_methods_array(comparison_op, dtype): a = pd.array(["a", None, "c"], dtype=dtype) other = [None, None, "c"] result = getattr(a, op_name)(other) - if dtype.storage == "pyarrow_numpy": + if dtype.na_value is np.nan: if operator.ne == comparison_op: expected = np.array([True, True, False]) else: @@ -387,7 +401,7 @@ def test_astype_int(dtype): tm.assert_numpy_array_equal(result, expected) arr = pd.array(["1", pd.NA, "3"], dtype=dtype) - if dtype.storage == "pyarrow_numpy": + if dtype.na_value is np.nan: err = ValueError msg = "cannot convert float NaN to integer" else: @@ -441,7 +455,7 @@ def test_min_max(method, skipna, dtype): expected = "a" if method == "min" else "c" assert result == expected else: - assert result is na_val(arr.dtype) + assert result is arr.dtype.na_value @pytest.mark.parametrize("method", ["min", "max"]) @@ -490,7 +504,7 @@ def test_arrow_array(dtype): data = pd.array(["a", "b", "c"], dtype=dtype) arr = pa.array(data) expected = pa.array(list(data), type=pa.large_string(), from_pandas=True) - if dtype.storage in ("pyarrow", "pyarrow_numpy") and pa_version_under12p0: + if dtype.storage == "pyarrow" and pa_version_under12p0: expected = pa.chunked_array(expected) if dtype.storage == "python": expected = pc.cast(expected, pa.string()) @@ -522,7 +536,7 @@ def test_arrow_roundtrip(dtype, string_storage2, request, using_infer_string): expected = df.astype(f"string[{string_storage2}]") tm.assert_frame_equal(result, expected) # ensure the missing value is represented by NA and not np.nan or None - assert result.loc[2, "a"] is na_val(result["a"].dtype) + assert result.loc[2, "a"] is result["a"].dtype.na_value @pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning") @@ -556,10 +570,10 @@ def test_arrow_load_from_zero_chunks( def test_value_counts_na(dtype): - if getattr(dtype, "storage", "") == "pyarrow": - exp_dtype = "int64[pyarrow]" - elif getattr(dtype, "storage", "") == "pyarrow_numpy": + if dtype.na_value is np.nan: exp_dtype = "int64" + elif dtype.storage == "pyarrow": + exp_dtype = "int64[pyarrow]" else: exp_dtype = "Int64" arr = pd.array(["a", "b", "a", pd.NA], dtype=dtype) @@ -573,10 +587,10 @@ def test_value_counts_na(dtype): def test_value_counts_with_normalize(dtype): - if getattr(dtype, "storage", "") == "pyarrow": - exp_dtype = "double[pyarrow]" - elif getattr(dtype, "storage", "") == "pyarrow_numpy": + if dtype.na_value is np.nan: exp_dtype = np.float64 + elif dtype.storage == "pyarrow": + exp_dtype = "double[pyarrow]" else: exp_dtype = "Float64" ser = pd.Series(["a", "b", "a", pd.NA], dtype=dtype) @@ -586,10 +600,10 @@ def test_value_counts_with_normalize(dtype): def test_value_counts_sort_false(dtype): - if getattr(dtype, "storage", "") == "pyarrow": - exp_dtype = "int64[pyarrow]" - elif getattr(dtype, "storage", "") == "pyarrow_numpy": + if dtype.na_value is np.nan: exp_dtype = "int64" + elif dtype.storage == "pyarrow": + exp_dtype = "int64[pyarrow]" else: exp_dtype = "Int64" ser = pd.Series(["a", "b", "c", "b"], dtype=dtype) @@ -621,7 +635,7 @@ def test_astype_from_float_dtype(float_dtype, dtype): def test_to_numpy_returns_pdna_default(dtype): arr = pd.array(["a", pd.NA, "b"], dtype=dtype) result = np.array(arr) - expected = np.array(["a", na_val(dtype), "b"], dtype=object) + expected = np.array(["a", dtype.na_value, "b"], dtype=object) tm.assert_numpy_array_equal(result, expected) @@ -661,7 +675,7 @@ def test_setitem_scalar_with_mask_validation(dtype): mask = np.array([False, True, False]) ser[mask] = None - assert ser.array[1] is na_val(ser.dtype) + assert ser.array[1] is ser.dtype.na_value # for other non-string we should also raise an error ser = pd.Series(["a", "b", "c"], dtype=dtype) diff --git a/pandas/tests/arrays/string_/test_string_arrow.py b/pandas/tests/arrays/string_/test_string_arrow.py index 405c1c217b04d..c610ef5315723 100644 --- a/pandas/tests/arrays/string_/test_string_arrow.py +++ b/pandas/tests/arrays/string_/test_string_arrow.py @@ -29,6 +29,8 @@ def test_eq_all_na(): def test_config(string_storage, request, using_infer_string): if using_infer_string and string_storage != "pyarrow_numpy": request.applymarker(pytest.mark.xfail(reason="infer string takes precedence")) + if string_storage == "pyarrow_numpy": + request.applymarker(pytest.mark.xfail(reason="TODO(infer_string)")) with pd.option_context("string_storage", string_storage): assert StringDtype().storage == string_storage result = pd.array(["a", "b"]) @@ -260,6 +262,6 @@ def test_pickle_roundtrip(dtype): def test_string_dtype_error_message(): # GH#55051 pytest.importorskip("pyarrow") - msg = "Storage must be 'python', 'pyarrow' or 'pyarrow_numpy'." + msg = "Storage must be 'python' or 'pyarrow'." with pytest.raises(ValueError, match=msg): StringDtype("bla") diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index b7f0f973e640a..dd2ed0bd62a02 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -66,14 +66,14 @@ def test_value_counts_with_normalize(self, data): expected = pd.Series(0.0, index=result.index, name="proportion") expected[result > 0] = 1 / len(values) - if getattr(data.dtype, "storage", "") == "pyarrow" or isinstance( + if isinstance(data.dtype, pd.StringDtype) and data.dtype.na_value is np.nan: + # TODO: avoid special-casing + expected = expected.astype("float64") + elif getattr(data.dtype, "storage", "") == "pyarrow" or isinstance( data.dtype, pd.ArrowDtype ): # TODO: avoid special-casing expected = expected.astype("double[pyarrow]") - elif getattr(data.dtype, "storage", "") == "pyarrow_numpy": - # TODO: avoid special-casing - expected = expected.astype("float64") elif na_value_for_dtype(data.dtype) is pd.NA: # TODO(GH#44692): avoid special-casing expected = expected.astype("Float64") diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 49ad3fce92a5c..4628c5568b49b 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -96,9 +96,15 @@ def data_for_grouping(dtype, chunked): class TestStringArray(base.ExtensionTests): def test_eq_with_str(self, dtype): - assert dtype == f"string[{dtype.storage}]" super().test_eq_with_str(dtype) + if dtype.na_value is pd.NA: + # only the NA-variant supports parametrized string alias + assert dtype == f"string[{dtype.storage}]" + elif dtype.storage == "pyarrow": + # TODO(infer_string) deprecate this + assert dtype == "string[pyarrow_numpy]" + def test_is_not_string_type(self, dtype): # Different from BaseDtypeTests.test_is_not_string_type # because StringDtype is a string type @@ -140,28 +146,21 @@ def _get_expected_exception( self, op_name: str, obj, other ) -> type[Exception] | None: if op_name in ["__divmod__", "__rdivmod__"]: - if isinstance(obj, pd.Series) and cast( - StringDtype, tm.get_dtype(obj) - ).storage in [ - "pyarrow", - "pyarrow_numpy", - ]: + if ( + isinstance(obj, pd.Series) + and cast(StringDtype, tm.get_dtype(obj)).storage == "pyarrow" + ): # TODO: re-raise as TypeError? return NotImplementedError - elif isinstance(other, pd.Series) and cast( - StringDtype, tm.get_dtype(other) - ).storage in [ - "pyarrow", - "pyarrow_numpy", - ]: + elif ( + isinstance(other, pd.Series) + and cast(StringDtype, tm.get_dtype(other)).storage == "pyarrow" + ): # TODO: re-raise as TypeError? return NotImplementedError return TypeError elif op_name in ["__mod__", "__rmod__", "__pow__", "__rpow__"]: - if cast(StringDtype, tm.get_dtype(obj)).storage in [ - "pyarrow", - "pyarrow_numpy", - ]: + if cast(StringDtype, tm.get_dtype(obj)).storage == "pyarrow": return NotImplementedError return TypeError elif op_name in ["__mul__", "__rmul__"]: @@ -175,10 +174,7 @@ def _get_expected_exception( "__sub__", "__rsub__", ]: - if cast(StringDtype, tm.get_dtype(obj)).storage in [ - "pyarrow", - "pyarrow_numpy", - ]: + if cast(StringDtype, tm.get_dtype(obj)).storage == "pyarrow": import pyarrow as pa # TODO: better to re-raise as TypeError? @@ -190,7 +186,7 @@ def _get_expected_exception( def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool: return ( op_name in ["min", "max"] - or ser.dtype.storage == "pyarrow_numpy" # type: ignore[union-attr] + or ser.dtype.na_value is np.nan # type: ignore[union-attr] and op_name in ("any", "all") ) @@ -198,10 +194,10 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): dtype = cast(StringDtype, tm.get_dtype(obj)) if op_name in ["__add__", "__radd__"]: cast_to = dtype + elif dtype.na_value is np.nan: + cast_to = np.bool_ # type: ignore[assignment] elif dtype.storage == "pyarrow": cast_to = "boolean[pyarrow]" # type: ignore[assignment] - elif dtype.storage == "pyarrow_numpy": - cast_to = np.bool_ # type: ignore[assignment] else: cast_to = "boolean" # type: ignore[assignment] return pointwise_result.astype(cast_to) diff --git a/pandas/tests/frame/methods/test_convert_dtypes.py b/pandas/tests/frame/methods/test_convert_dtypes.py index 521d2cb14ac6a..9cbbebf35b2d1 100644 --- a/pandas/tests/frame/methods/test_convert_dtypes.py +++ b/pandas/tests/frame/methods/test_convert_dtypes.py @@ -18,6 +18,7 @@ def test_convert_dtypes( # Just check that it works for DataFrame here if using_infer_string: string_storage = "pyarrow_numpy" + df = pd.DataFrame( { "a": pd.Series([1, 2, 3], dtype=np.dtype("int32")), diff --git a/pandas/tests/series/test_constructors.py b/pandas/tests/series/test_constructors.py index 44a7862c21273..91cf1708ed43b 100644 --- a/pandas/tests/series/test_constructors.py +++ b/pandas/tests/series/test_constructors.py @@ -2113,9 +2113,12 @@ def test_series_string_inference_array_string_dtype(self): tm.assert_series_equal(ser, expected) def test_series_string_inference_storage_definition(self): - # GH#54793 + # https://github.com/pandas-dev/pandas/issues/54793 + # but after PDEP-14 (string dtype), it was decided to keep dtype="string" + # returning the NA string dtype, so expected is changed from + # "string[pyarrow_numpy]" to "string[pyarrow]" pytest.importorskip("pyarrow") - expected = Series(["a", "b"], dtype="string[pyarrow_numpy]") + expected = Series(["a", "b"], dtype="string[pyarrow]") with pd.option_context("future.infer_string", True): result = Series(["a", "b"], dtype="string") tm.assert_series_equal(result, expected) diff --git a/pandas/tests/strings/__init__.py b/pandas/tests/strings/__init__.py index 01b49b5e5b633..e94f656fc9823 100644 --- a/pandas/tests/strings/__init__.py +++ b/pandas/tests/strings/__init__.py @@ -7,7 +7,7 @@ def _convert_na_value(ser, expected): if ser.dtype != object: - if ser.dtype.storage == "pyarrow_numpy": + if ser.dtype.na_value is np.nan: expected = expected.fillna(np.nan) else: # GH#18463