Skip to content

Commit ccb90e3

Browse files
String dtype: implement _get_common_dtype (pandas-dev#59682)
* String dtype: implement _get_common_dtype * add specific tests * try fix typing * try fix typing * suppress typing error * support numpy 2.0 string * fix typo
1 parent 44325c1 commit ccb90e3

File tree

3 files changed

+103
-5
lines changed

3 files changed

+103
-5
lines changed

pandas/core/arrays/string_.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,9 @@ def __init__(
167167
# a consistent NaN value (and we can use `dtype.na_value is np.nan`)
168168
na_value = np.nan
169169
elif na_value is not libmissing.NA:
170-
raise ValueError("'na_value' must be np.nan or pd.NA, got {na_value}")
170+
raise ValueError(f"'na_value' must be np.nan or pd.NA, got {na_value}")
171171

172-
self.storage = storage
172+
self.storage = cast(str, storage)
173173
self._na_value = na_value
174174

175175
def __repr__(self) -> str:
@@ -280,6 +280,34 @@ def construct_array_type( # type: ignore[override]
280280
else:
281281
return ArrowStringArrayNumpySemantics
282282

283+
def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
284+
storages = set()
285+
na_values = set()
286+
287+
for dtype in dtypes:
288+
if isinstance(dtype, StringDtype):
289+
storages.add(dtype.storage)
290+
na_values.add(dtype.na_value)
291+
elif isinstance(dtype, np.dtype) and dtype.kind in ("U", "T"):
292+
continue
293+
else:
294+
return None
295+
296+
if len(storages) == 2:
297+
# if both python and pyarrow storage -> priority to pyarrow
298+
storage = "pyarrow"
299+
else:
300+
storage = next(iter(storages)) # type: ignore[assignment]
301+
302+
na_value: libmissing.NAType | float
303+
if len(na_values) == 2:
304+
# if both NaN and NA -> priority to NA
305+
na_value = libmissing.NA
306+
else:
307+
na_value = next(iter(na_values))
308+
309+
return StringDtype(storage=storage, na_value=na_value)
310+
283311
def __from_arrow__(
284312
self, array: pyarrow.Array | pyarrow.ChunkedArray
285313
) -> BaseStringArray:

pandas/tests/arrays/categorical/test_api.py

-3
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
from pandas._config import using_string_dtype
7-
86
from pandas.compat import PY311
97

108
from pandas import (
@@ -158,7 +156,6 @@ def test_reorder_categories_raises(self, new_categories):
158156
with pytest.raises(ValueError, match=msg):
159157
cat.reorder_categories(new_categories)
160158

161-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
162159
def test_add_categories(self):
163160
cat = Categorical(["a", "b", "c", "a"], ordered=True)
164161
old = cat.copy()
+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pandas.compat import HAS_PYARROW
5+
6+
from pandas.core.dtypes.cast import find_common_type
7+
8+
import pandas as pd
9+
import pandas._testing as tm
10+
from pandas.util.version import Version
11+
12+
13+
@pytest.mark.parametrize(
14+
"to_concat_dtypes, result_dtype",
15+
[
16+
# same types
17+
([("pyarrow", pd.NA), ("pyarrow", pd.NA)], ("pyarrow", pd.NA)),
18+
([("pyarrow", np.nan), ("pyarrow", np.nan)], ("pyarrow", np.nan)),
19+
([("python", pd.NA), ("python", pd.NA)], ("python", pd.NA)),
20+
([("python", np.nan), ("python", np.nan)], ("python", np.nan)),
21+
# pyarrow preference
22+
([("pyarrow", pd.NA), ("python", pd.NA)], ("pyarrow", pd.NA)),
23+
# NA preference
24+
([("python", pd.NA), ("python", np.nan)], ("python", pd.NA)),
25+
],
26+
)
27+
def test_concat_series(request, to_concat_dtypes, result_dtype):
28+
if any(storage == "pyarrow" for storage, _ in to_concat_dtypes) and not HAS_PYARROW:
29+
pytest.skip("Could not import 'pyarrow'")
30+
31+
ser_list = [
32+
pd.Series(["a", "b", None], dtype=pd.StringDtype(storage, na_value))
33+
for storage, na_value in to_concat_dtypes
34+
]
35+
36+
result = pd.concat(ser_list, ignore_index=True)
37+
expected = pd.Series(
38+
["a", "b", None, "a", "b", None], dtype=pd.StringDtype(*result_dtype)
39+
)
40+
tm.assert_series_equal(result, expected)
41+
42+
# order doesn't matter for result
43+
result = pd.concat(ser_list[::1], ignore_index=True)
44+
tm.assert_series_equal(result, expected)
45+
46+
47+
def test_concat_with_object(string_dtype_arguments):
48+
# _get_common_dtype cannot inspect values, so object dtype with strings still
49+
# results in object dtype
50+
result = pd.concat(
51+
[
52+
pd.Series(["a", "b", None], dtype=pd.StringDtype(*string_dtype_arguments)),
53+
pd.Series(["a", "b", None], dtype=object),
54+
]
55+
)
56+
assert result.dtype == np.dtype("object")
57+
58+
59+
def test_concat_with_numpy(string_dtype_arguments):
60+
# common type with a numpy string dtype always preserves the pandas string dtype
61+
dtype = pd.StringDtype(*string_dtype_arguments)
62+
assert find_common_type([dtype, np.dtype("U")]) == dtype
63+
assert find_common_type([np.dtype("U"), dtype]) == dtype
64+
assert find_common_type([dtype, np.dtype("U10")]) == dtype
65+
assert find_common_type([np.dtype("U10"), dtype]) == dtype
66+
67+
# with any other numpy dtype -> object
68+
assert find_common_type([dtype, np.dtype("S")]) == np.dtype("object")
69+
assert find_common_type([dtype, np.dtype("int64")]) == np.dtype("object")
70+
71+
if Version(np.__version__) >= Version("2"):
72+
assert find_common_type([dtype, np.dtypes.StringDType()]) == dtype
73+
assert find_common_type([np.dtypes.StringDType(), dtype]) == dtype

0 commit comments

Comments
 (0)