String dtype: implement _get_common_dtype #59682

Merged
32 changes: 30 additions & 2 deletions pandas/core/arrays/string_.py
@@ -171,9 +171,9 @@ def __init__(
# a consistent NaN value (and we can use `dtype.na_value is np.nan`)
na_value = np.nan
elif na_value is not libmissing.NA:
raise ValueError("'na_value' must be np.nan or pd.NA, got {na_value}")
raise ValueError(f"'na_value' must be np.nan or pd.NA, got {na_value}")

self.storage = storage
self.storage = cast(str, storage)
self._na_value = na_value

def __repr__(self) -> str:
@@ -284,6 +284,34 @@ def construct_array_type( # type: ignore[override]
else:
return ArrowStringArrayNumpySemantics

def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
storages = set()
na_values = set()

for dtype in dtypes:
if isinstance(dtype, StringDtype):
storages.add(dtype.storage)
na_values.add(dtype.na_value)
elif isinstance(dtype, np.dtype) and dtype.kind in ("U", "T"):
continue
else:
return None

if len(storages) == 2:
# if both python and pyarrow storage -> priority to pyarrow
storage = "pyarrow"
else:
storage = next(iter(storages)) # type: ignore[assignment]

na_value: libmissing.NAType | float
if len(na_values) == 2:
# if both NaN and NA -> priority to NA
na_value = libmissing.NA
else:
na_value = next(iter(na_values))

return StringDtype(storage=storage, na_value=na_value)

def __from_arrow__(
self, array: pyarrow.Array | pyarrow.ChunkedArray
) -> BaseStringArray:
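As a quick illustration of what this resolution means in practice (a minimal sketch, assuming pyarrow is installed and this branch is in use; the exact repr of the resulting dtype may differ):

import numpy as np
import pandas as pd

# Two string Series with different storage and different NA semantics.
s1 = pd.Series(["a", None], dtype=pd.StringDtype("python", na_value=pd.NA))
s2 = pd.Series(["b", None], dtype=pd.StringDtype("pyarrow", na_value=np.nan))

# With _get_common_dtype in place, concatenation resolves to pyarrow storage
# (preferred over python) and pd.NA (preferred over np.nan).
result = pd.concat([s1, s2], ignore_index=True)
print(result.dtype)  # expected: a StringDtype with pyarrow storage and pd.NA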
3 changes: 0 additions & 3 deletions pandas/tests/arrays/categorical/test_api.py
@@ -3,8 +3,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.compat import PY311

from pandas import (
@@ -151,7 +149,6 @@ def test_reorder_categories_raises(self, new_categories):
with pytest.raises(ValueError, match=msg):
cat.reorder_categories(new_categories)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_add_categories(self):
cat = Categorical(["a", "b", "c", "a"], ordered=True)
old = cat.copy()
73 changes: 73 additions & 0 deletions pandas/tests/arrays/string_/test_concat.py
@@ -0,0 +1,73 @@
import numpy as np
import pytest

from pandas.compat import HAS_PYARROW

from pandas.core.dtypes.cast import find_common_type

import pandas as pd
import pandas._testing as tm
from pandas.util.version import Version


@pytest.mark.parametrize(
"to_concat_dtypes, result_dtype",
[
# same types
([("pyarrow", pd.NA), ("pyarrow", pd.NA)], ("pyarrow", pd.NA)),
([("pyarrow", np.nan), ("pyarrow", np.nan)], ("pyarrow", np.nan)),
([("python", pd.NA), ("python", pd.NA)], ("python", pd.NA)),
([("python", np.nan), ("python", np.nan)], ("python", np.nan)),
# pyarrow preference
([("pyarrow", pd.NA), ("python", pd.NA)], ("pyarrow", pd.NA)),
# NA preference
([("python", pd.NA), ("python", np.nan)], ("python", pd.NA)),
],
)
def test_concat_series(request, to_concat_dtypes, result_dtype):
if any(storage == "pyarrow" for storage, _ in to_concat_dtypes) and not HAS_PYARROW:
pytest.skip("Could not import 'pyarrow'")

ser_list = [
pd.Series(["a", "b", None], dtype=pd.StringDtype(storage, na_value))
for storage, na_value in to_concat_dtypes
]

result = pd.concat(ser_list, ignore_index=True)
expected = pd.Series(
["a", "b", None, "a", "b", None], dtype=pd.StringDtype(*result_dtype)
)
tm.assert_series_equal(result, expected)

# order doesn't matter for result
    result = pd.concat(ser_list[::-1], ignore_index=True)
tm.assert_series_equal(result, expected)


def test_concat_with_object(string_dtype_arguments):
# _get_common_dtype cannot inspect values, so object dtype with strings still
# results in object dtype
result = pd.concat(
[
pd.Series(["a", "b", None], dtype=pd.StringDtype(*string_dtype_arguments)),
pd.Series(["a", "b", None], dtype=object),
]
)
assert result.dtype == np.dtype("object")


def test_concat_with_numpy(string_dtype_arguments):
# common type with a numpy string dtype always preserves the pandas string dtype
dtype = pd.StringDtype(*string_dtype_arguments)
assert find_common_type([dtype, np.dtype("U")]) == dtype
assert find_common_type([np.dtype("U"), dtype]) == dtype
assert find_common_type([dtype, np.dtype("U10")]) == dtype
assert find_common_type([np.dtype("U10"), dtype]) == dtype

# with any other numpy dtype -> object
assert find_common_type([dtype, np.dtype("S")]) == np.dtype("object")
assert find_common_type([dtype, np.dtype("int64")]) == np.dtype("object")

if Version(np.__version__) >= Version("2"):
assert find_common_type([dtype, np.dtypes.StringDType()]) == dtype
assert find_common_type([np.dtypes.StringDType(), dtype]) == dtype
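
The new tests can be exercised locally with, e.g., python -m pytest pandas/tests/arrays/string_/test_concat.py; the pyarrow-parametrized cases are skipped when pyarrow is not installed, and the np.dtypes.StringDType assertions only run on NumPy 2 or newer.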