diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py
index 143a13c54dbbb..88fd1481031f8 100644
--- a/pandas/core/arrays/string_.py
+++ b/pandas/core/arrays/string_.py
@@ -171,9 +171,9 @@ def __init__(
             # a consistent NaN value (and we can use `dtype.na_value is np.nan`)
             na_value = np.nan
         elif na_value is not libmissing.NA:
-            raise ValueError("'na_value' must be np.nan or pd.NA, got {na_value}")
+            raise ValueError(f"'na_value' must be np.nan or pd.NA, got {na_value}")
 
-        self.storage = storage
+        self.storage = cast(str, storage)
         self._na_value = na_value
 
     def __repr__(self) -> str:
@@ -284,6 +284,34 @@ def construct_array_type(  # type: ignore[override]
         else:
             return ArrowStringArrayNumpySemantics
 
+    def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
+        storages = set()
+        na_values = set()
+
+        for dtype in dtypes:
+            if isinstance(dtype, StringDtype):
+                storages.add(dtype.storage)
+                na_values.add(dtype.na_value)
+            elif isinstance(dtype, np.dtype) and dtype.kind in ("U", "T"):
+                continue
+            else:
+                return None
+
+        if len(storages) == 2:
+            # if both python and pyarrow storage -> priority to pyarrow
+            storage = "pyarrow"
+        else:
+            storage = next(iter(storages))  # type: ignore[assignment]
+
+        na_value: libmissing.NAType | float
+        if len(na_values) == 2:
+            # if both NaN and NA -> priority to NA
+            na_value = libmissing.NA
+        else:
+            na_value = next(iter(na_values))
+
+        return StringDtype(storage=storage, na_value=na_value)
+
     def __from_arrow__(
         self, array: pyarrow.Array | pyarrow.ChunkedArray
     ) -> BaseStringArray:
diff --git a/pandas/tests/arrays/categorical/test_api.py b/pandas/tests/arrays/categorical/test_api.py
index 2ccc5781c608e..2791fd55f54d7 100644
--- a/pandas/tests/arrays/categorical/test_api.py
+++ b/pandas/tests/arrays/categorical/test_api.py
@@ -3,8 +3,6 @@
 import numpy as np
 import pytest
 
-from pandas._config import using_string_dtype
-
 from pandas.compat import PY311
 
 from pandas import (
@@ -151,7 +149,6 @@ def test_reorder_categories_raises(self, new_categories):
         with pytest.raises(ValueError, match=msg):
             cat.reorder_categories(new_categories)
 
-    @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
     def test_add_categories(self):
         cat = Categorical(["a", "b", "c", "a"], ordered=True)
         old = cat.copy()
diff --git a/pandas/tests/arrays/string_/test_concat.py b/pandas/tests/arrays/string_/test_concat.py
new file mode 100644
index 0000000000000..320d700b2b6c3
--- /dev/null
+++ b/pandas/tests/arrays/string_/test_concat.py
@@ -0,0 +1,73 @@
+import numpy as np
+import pytest
+
+from pandas.compat import HAS_PYARROW
+
+from pandas.core.dtypes.cast import find_common_type
+
+import pandas as pd
+import pandas._testing as tm
+from pandas.util.version import Version
+
+
+@pytest.mark.parametrize(
+    "to_concat_dtypes, result_dtype",
+    [
+        # same types
+        ([("pyarrow", pd.NA), ("pyarrow", pd.NA)], ("pyarrow", pd.NA)),
+        ([("pyarrow", np.nan), ("pyarrow", np.nan)], ("pyarrow", np.nan)),
+        ([("python", pd.NA), ("python", pd.NA)], ("python", pd.NA)),
+        ([("python", np.nan), ("python", np.nan)], ("python", np.nan)),
+        # pyarrow preference
+        ([("pyarrow", pd.NA), ("python", pd.NA)], ("pyarrow", pd.NA)),
+        # NA preference
+        ([("python", pd.NA), ("python", np.nan)], ("python", pd.NA)),
+    ],
+)
+def test_concat_series(request, to_concat_dtypes, result_dtype):
+    if any(storage == "pyarrow" for storage, _ in to_concat_dtypes) and not HAS_PYARROW:
+        pytest.skip("Could not import 'pyarrow'")
+
+    ser_list = [
+        pd.Series(["a", "b", None], dtype=pd.StringDtype(storage, na_value))
+        for storage, na_value in to_concat_dtypes
+    ]
+
+    result = pd.concat(ser_list, ignore_index=True)
+    expected = pd.Series(
+        ["a", "b", None, "a", "b", None], dtype=pd.StringDtype(*result_dtype)
+    )
+    tm.assert_series_equal(result, expected)
+
+    # order doesn't matter for result
+    result = pd.concat(ser_list[::-1], ignore_index=True)
+    tm.assert_series_equal(result, expected)
+
+
+def test_concat_with_object(string_dtype_arguments):
+    # _get_common_dtype cannot inspect values, so object dtype with strings still
+    # results in object dtype
+    result = pd.concat(
+        [
+            pd.Series(["a", "b", None], dtype=pd.StringDtype(*string_dtype_arguments)),
+            pd.Series(["a", "b", None], dtype=object),
+        ]
+    )
+    assert result.dtype == np.dtype("object")
+
+
+def test_concat_with_numpy(string_dtype_arguments):
+    # common type with a numpy string dtype always preserves the pandas string dtype
+    dtype = pd.StringDtype(*string_dtype_arguments)
+    assert find_common_type([dtype, np.dtype("U")]) == dtype
+    assert find_common_type([np.dtype("U"), dtype]) == dtype
+    assert find_common_type([dtype, np.dtype("U10")]) == dtype
+    assert find_common_type([np.dtype("U10"), dtype]) == dtype
+
+    # with any other numpy dtype -> object
+    assert find_common_type([dtype, np.dtype("S")]) == np.dtype("object")
+    assert find_common_type([dtype, np.dtype("int64")]) == np.dtype("object")
+
+    if Version(np.__version__) >= Version("2"):
+        assert find_common_type([dtype, np.dtypes.StringDType()]) == dtype
+        assert find_common_type([np.dtypes.StringDType(), dtype]) == dtype
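
Illustrative usage (not part of the patch): a minimal sketch of the behaviour implied by the new StringDtype._get_common_dtype and the tests above, assuming the patch is applied and pyarrow is installed. It uses only find_common_type and the StringDtype constructor as exercised in the test file; the variable names are mine.

    import numpy as np
    import pandas as pd
    from pandas.core.dtypes.cast import find_common_type

    # Mixed storages resolve to "pyarrow"; mixed missing-value sentinels resolve to pd.NA.
    common = find_common_type(
        [
            pd.StringDtype("python", na_value=np.nan),
            pd.StringDtype("pyarrow", na_value=pd.NA),
        ]
    )
    assert common == pd.StringDtype("pyarrow", na_value=pd.NA)

    # Concatenating string Series with different string dtypes therefore keeps a
    # pandas string dtype instead of falling back to object.
    ser1 = pd.Series(["a", None], dtype=pd.StringDtype("python", na_value=np.nan))
    ser2 = pd.Series(["b", None], dtype=pd.StringDtype("pyarrow", na_value=pd.NA))
    assert pd.concat([ser1, ser2], ignore_index=True).dtype == common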