From 0899e4ee4bc941f85fdc5307497e8d60ea39ba9b Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 2 Sep 2024 09:30:39 +0200 Subject: [PATCH 1/7] String dtype: implement _get_common_dtype --- pandas/core/arrays/string_.py | 31 ++++++++++++++++++++- pandas/tests/arrays/categorical/test_api.py | 3 -- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 143a13c54dbbb..d86d5141ab349 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -171,7 +171,7 @@ def __init__( # a consistent NaN value (and we can use `dtype.na_value is np.nan`) na_value = np.nan elif na_value is not libmissing.NA: - raise ValueError("'na_value' must be np.nan or pd.NA, got {na_value}") + raise ValueError(f"'na_value' must be np.nan or pd.NA, got {na_value}") self.storage = storage self._na_value = na_value @@ -284,6 +284,35 @@ def construct_array_type( # type: ignore[override] else: return ArrowStringArrayNumpySemantics + def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: + allowed_numpy_kinds = {"S", "U"} + + storages = set() + na_values = set() + + for dtype in dtypes: + if isinstance(dtype, StringDtype): + storages.add(dtype.storage) + na_values.add(dtype.na_value) + elif isinstance(dtype, np.dtype) and dtype.kind in allowed_numpy_kinds: + continue + else: + return None + + if len(storages) == 2: + # if both python and pyarrow storage -> priority to pyarrow + storage = "pyarrow" + else: + storage = next(iter(storages)) + + if len(na_values) == 2: + # if both NaN and NA -> priority to NA + na_value = libmissing.NA + else: + na_value = next(iter(na_values)) + + return StringDtype(storage=storage, na_value=na_value) + def __from_arrow__( self, array: pyarrow.Array | pyarrow.ChunkedArray ) -> BaseStringArray: diff --git a/pandas/tests/arrays/categorical/test_api.py b/pandas/tests/arrays/categorical/test_api.py index 2ccc5781c608e..2791fd55f54d7 100644 --- a/pandas/tests/arrays/categorical/test_api.py +++ b/pandas/tests/arrays/categorical/test_api.py @@ -3,8 +3,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - from pandas.compat import PY311 from pandas import ( @@ -151,7 +149,6 @@ def test_reorder_categories_raises(self, new_categories): with pytest.raises(ValueError, match=msg): cat.reorder_categories(new_categories) - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_add_categories(self): cat = Categorical(["a", "b", "c", "a"], ordered=True) old = cat.copy() From 1bf8a69346d52833f7eab46574c4c72ddd772b29 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 2 Sep 2024 10:10:36 +0200 Subject: [PATCH 2/7] add specific tests --- pandas/core/arrays/string_.py | 4 +- pandas/tests/arrays/string_/test_concat.py | 68 ++++++++++++++++++++++ 2 files changed, 69 insertions(+), 3 deletions(-) create mode 100644 pandas/tests/arrays/string_/test_concat.py diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index d86d5141ab349..d644f40e6da36 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -285,8 +285,6 @@ def construct_array_type( # type: ignore[override] return ArrowStringArrayNumpySemantics def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: - allowed_numpy_kinds = {"S", "U"} - storages = set() na_values = set() @@ -294,7 +292,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: if isinstance(dtype, StringDtype): storages.add(dtype.storage) na_values.add(dtype.na_value) - elif isinstance(dtype, np.dtype) and dtype.kind in allowed_numpy_kinds: + elif isinstance(dtype, np.dtype) and dtype.kind == "U": continue else: return None diff --git a/pandas/tests/arrays/string_/test_concat.py b/pandas/tests/arrays/string_/test_concat.py new file mode 100644 index 0000000000000..1d31d8940700b --- /dev/null +++ b/pandas/tests/arrays/string_/test_concat.py @@ -0,0 +1,68 @@ +import numpy as np +import pytest + +from pandas.compat import HAS_PYARROW + +from pandas.core.dtypes.cast import find_common_type + +import pandas as pd +import pandas._testing as tm + + +@pytest.mark.parametrize( + "to_concat_dtypes, result_dtype", + [ + # same types + ([("pyarrow", pd.NA), ("pyarrow", pd.NA)], ("pyarrow", pd.NA)), + ([("pyarrow", np.nan), ("pyarrow", np.nan)], ("pyarrow", np.nan)), + ([("python", pd.NA), ("python", pd.NA)], ("python", pd.NA)), + ([("python", np.nan), ("python", np.nan)], ("python", np.nan)), + # pyarrow preference + ([("pyarrow", pd.NA), ("python", pd.NA)], ("pyarrow", pd.NA)), + # NA preference + ([("python", pd.NA), ("python", np.nan)], ("python", pd.NA)), + ], +) +def test_concat_series(request, to_concat_dtypes, result_dtype): + if any(storage == "pyarrow" for storage, _ in to_concat_dtypes) and not HAS_PYARROW: + pytest.skip("Could not import 'pyarrow'") + + ser_list = [ + pd.Series(["a", "b", None], dtype=pd.StringDtype(storage, na_value)) + for storage, na_value in to_concat_dtypes + ] + + result = pd.concat(ser_list, ignore_index=True) + expected = pd.Series( + ["a", "b", None, "a", "b", None], dtype=pd.StringDtype(*result_dtype) + ) + tm.assert_series_equal(result, expected) + + # order doesn't matter for result + result = pd.concat(ser_list[::1], ignore_index=True) + tm.assert_series_equal(result, expected) + + +def test_concat_with_object(string_dtype_arguments): + # _get_common_dtype cannot inspect values, so object dtype with strings still + # results in object dtype + result = pd.concat( + [ + pd.Series(["a", "b", None], dtype=pd.StringDtype(*string_dtype_arguments)), + pd.Series(["a", "b", None], dtype=object), + ] + ) + assert result.dtype == np.dtype("object") + + +def test_concat_with_numpy(string_dtype_arguments): + # common type with a numpy string dtype always preserves the pandas string dtype + dtype = pd.StringDtype(*string_dtype_arguments) + assert find_common_type([dtype, np.dtype("U")]) == dtype + assert find_common_type([np.dtype("U"), dtype]) == dtype + assert find_common_type([dtype, np.dtype("U10")]) == dtype + assert find_common_type([np.dtype("U10"), dtype]) == dtype + + # with any other numpy dtype -> object + assert find_common_type([dtype, np.dtype("S")]) == np.dtype("object") + assert find_common_type([dtype, np.dtype("int64")]) == np.dtype("object") From 4132ce4ad6050bec3b6eff5b4b661bfeb7466c3e Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 2 Sep 2024 10:13:15 +0200 Subject: [PATCH 3/7] try fix typing --- pandas/core/arrays/string_.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index d644f40e6da36..dcb61f3142d07 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -297,12 +297,14 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: else: return None + storage: str if len(storages) == 2: # if both python and pyarrow storage -> priority to pyarrow storage = "pyarrow" else: storage = next(iter(storages)) + na_value: libmissing.NAType | float if len(na_values) == 2: # if both NaN and NA -> priority to NA na_value = libmissing.NA From 25f19750543bbc53efd470da397989a3a68ef445 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 2 Sep 2024 14:08:25 +0200 Subject: [PATCH 4/7] try fix typing --- pandas/core/arrays/string_.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index dcb61f3142d07..0311189739a83 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -173,7 +173,7 @@ def __init__( elif na_value is not libmissing.NA: raise ValueError(f"'na_value' must be np.nan or pd.NA, got {na_value}") - self.storage = storage + self.storage = cast(str, storage) self._na_value = na_value def __repr__(self) -> str: @@ -297,7 +297,6 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: else: return None - storage: str if len(storages) == 2: # if both python and pyarrow storage -> priority to pyarrow storage = "pyarrow" From 8921a6c76ab43d6d1eda7b4c8829ab43f5544c7a Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 2 Sep 2024 17:34:50 +0200 Subject: [PATCH 5/7] suppress typing error --- pandas/core/arrays/string_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 0311189739a83..2c572b9eb99f8 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -301,7 +301,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: # if both python and pyarrow storage -> priority to pyarrow storage = "pyarrow" else: - storage = next(iter(storages)) + storage = next(iter(storages)) # type: ignore[assignment] na_value: libmissing.NAType | float if len(na_values) == 2: From 8f181486481623836974ef577dcb5fc28e4815b8 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 6 Sep 2024 16:27:48 +0200 Subject: [PATCH 6/7] support numpy 2.0 string --- pandas/core/arrays/string_.py | 2 +- pandas/tests/arrays/string_/test_concat.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 2c572b9eb99f8..88fd1481031f8 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -292,7 +292,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: if isinstance(dtype, StringDtype): storages.add(dtype.storage) na_values.add(dtype.na_value) - elif isinstance(dtype, np.dtype) and dtype.kind == "U": + elif isinstance(dtype, np.dtype) and dtype.kind in ("U", "T"): continue else: return None diff --git a/pandas/tests/arrays/string_/test_concat.py b/pandas/tests/arrays/string_/test_concat.py index 1d31d8940700b..bd012a936f355 100644 --- a/pandas/tests/arrays/string_/test_concat.py +++ b/pandas/tests/arrays/string_/test_concat.py @@ -7,6 +7,7 @@ import pandas as pd import pandas._testing as tm +from pandas.util.version import Version @pytest.mark.parametrize( @@ -66,3 +67,7 @@ def test_concat_with_numpy(string_dtype_arguments): # with any other numpy dtype -> object assert find_common_type([dtype, np.dtype("S")]) == np.dtype("object") assert find_common_type([dtype, np.dtype("int64")]) == np.dtype("object") + + if Version(np.__version__) >= Version("2"): + assert find_common_type([dtype, np.dtypes.StringDtype()]) == dtype + assert find_common_type([np.dtypes.StringDtype(), dtype]) == dtype From be915f7634b35a23d6a1b23cac5faf8e6cd5c564 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 6 Sep 2024 17:08:00 +0200 Subject: [PATCH 7/7] fix typo --- pandas/tests/arrays/string_/test_concat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/tests/arrays/string_/test_concat.py b/pandas/tests/arrays/string_/test_concat.py index bd012a936f355..320d700b2b6c3 100644 --- a/pandas/tests/arrays/string_/test_concat.py +++ b/pandas/tests/arrays/string_/test_concat.py @@ -69,5 +69,5 @@ def test_concat_with_numpy(string_dtype_arguments): assert find_common_type([dtype, np.dtype("int64")]) == np.dtype("object") if Version(np.__version__) >= Version("2"): - assert find_common_type([dtype, np.dtypes.StringDtype()]) == dtype - assert find_common_type([np.dtypes.StringDtype(), dtype]) == dtype + assert find_common_type([dtype, np.dtypes.StringDType()]) == dtype + assert find_common_type([np.dtypes.StringDType(), dtype]) == dtype