Commit 4a16b44

String dtype: implement _get_common_dtype (#59682)
* String dtype: implement _get_common_dtype
* add specific tests
* try fix typing
* try fix typing
* suppress typing error
* support numpy 2.0 string
* fix typo
1 parent 8cd761a commit 4a16b44

File tree

3 files changed (+103 −5 lines)

pandas/core/arrays/string_.py

+30 −2
@@ -171,9 +171,9 @@ def __init__(
             # a consistent NaN value (and we can use `dtype.na_value is np.nan`)
             na_value = np.nan
         elif na_value is not libmissing.NA:
-            raise ValueError("'na_value' must be np.nan or pd.NA, got {na_value}")
+            raise ValueError(f"'na_value' must be np.nan or pd.NA, got {na_value}")

-        self.storage = storage
+        self.storage = cast(str, storage)
         self._na_value = na_value

     def __repr__(self) -> str:
@@ -284,6 +284,34 @@ def construct_array_type( # type: ignore[override]
         else:
             return ArrowStringArrayNumpySemantics

+    def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
+        storages = set()
+        na_values = set()
+
+        for dtype in dtypes:
+            if isinstance(dtype, StringDtype):
+                storages.add(dtype.storage)
+                na_values.add(dtype.na_value)
+            elif isinstance(dtype, np.dtype) and dtype.kind in ("U", "T"):
+                continue
+            else:
+                return None
+
+        if len(storages) == 2:
+            # if both python and pyarrow storage -> priority to pyarrow
+            storage = "pyarrow"
+        else:
+            storage = next(iter(storages))  # type: ignore[assignment]
+
+        na_value: libmissing.NAType | float
+        if len(na_values) == 2:
+            # if both NaN and NA -> priority to NA
+            na_value = libmissing.NA
+        else:
+            na_value = next(iter(na_values))
+
+        return StringDtype(storage=storage, na_value=na_value)
+
     def __from_arrow__(
         self, array: pyarrow.Array | pyarrow.ChunkedArray
     ) -> BaseStringArray:
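The new _get_common_dtype above is the hook that find_common_type (and therefore pd.concat) consults when combining differently parametrized string dtypes. A minimal sketch of the resulting behaviour, assuming a pandas build that includes this commit and, for the first case, an installed pyarrow; the expected dtypes mirror the test cases added below:

import numpy as np
import pandas as pd

# mixed storage: "pyarrow" takes priority over "python"
a = pd.Series(["a", "b", None], dtype=pd.StringDtype("pyarrow", na_value=pd.NA))
b = pd.Series(["a", "b", None], dtype=pd.StringDtype("python", na_value=pd.NA))
assert pd.concat([a, b]).dtype == pd.StringDtype("pyarrow", na_value=pd.NA)

# mixed missing-value sentinel: pd.NA takes priority over np.nan
c = pd.Series(["a", "b", None], dtype=pd.StringDtype("python", na_value=pd.NA))
d = pd.Series(["a", "b", None], dtype=pd.StringDtype("python", na_value=np.nan))
assert pd.concat([c, d]).dtype == pd.StringDtype("python", na_value=pd.NA)

# a non-string dtype in the mix still resolves to object
e = pd.Series(["a", "b", None], dtype=object)
assert pd.concat([c, e]).dtype == np.dtype("object")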

pandas/tests/arrays/categorical/test_api.py

-3
@@ -3,8 +3,6 @@
 import numpy as np
 import pytest

-from pandas._config import using_string_dtype
-
 from pandas.compat import PY311

 from pandas import (
@@ -151,7 +149,6 @@ def test_reorder_categories_raises(self, new_categories):
         with pytest.raises(ValueError, match=msg):
             cat.reorder_categories(new_categories)

-    @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
     def test_add_categories(self):
         cat = Categorical(["a", "b", "c", "a"], ordered=True)
         old = cat.copy()
+73 (new file)
@@ -0,0 +1,73 @@
+import numpy as np
+import pytest
+
+from pandas.compat import HAS_PYARROW
+
+from pandas.core.dtypes.cast import find_common_type
+
+import pandas as pd
+import pandas._testing as tm
+from pandas.util.version import Version
+
+
+@pytest.mark.parametrize(
+    "to_concat_dtypes, result_dtype",
+    [
+        # same types
+        ([("pyarrow", pd.NA), ("pyarrow", pd.NA)], ("pyarrow", pd.NA)),
+        ([("pyarrow", np.nan), ("pyarrow", np.nan)], ("pyarrow", np.nan)),
+        ([("python", pd.NA), ("python", pd.NA)], ("python", pd.NA)),
+        ([("python", np.nan), ("python", np.nan)], ("python", np.nan)),
+        # pyarrow preference
+        ([("pyarrow", pd.NA), ("python", pd.NA)], ("pyarrow", pd.NA)),
+        # NA preference
+        ([("python", pd.NA), ("python", np.nan)], ("python", pd.NA)),
+    ],
+)
+def test_concat_series(request, to_concat_dtypes, result_dtype):
+    if any(storage == "pyarrow" for storage, _ in to_concat_dtypes) and not HAS_PYARROW:
+        pytest.skip("Could not import 'pyarrow'")
+
+    ser_list = [
+        pd.Series(["a", "b", None], dtype=pd.StringDtype(storage, na_value))
+        for storage, na_value in to_concat_dtypes
+    ]
+
+    result = pd.concat(ser_list, ignore_index=True)
+    expected = pd.Series(
+        ["a", "b", None, "a", "b", None], dtype=pd.StringDtype(*result_dtype)
+    )
+    tm.assert_series_equal(result, expected)
+
+    # order doesn't matter for result
+    result = pd.concat(ser_list[::-1], ignore_index=True)
+    tm.assert_series_equal(result, expected)
+
+
+def test_concat_with_object(string_dtype_arguments):
+    # _get_common_dtype cannot inspect values, so object dtype with strings still
+    # results in object dtype
+    result = pd.concat(
+        [
+            pd.Series(["a", "b", None], dtype=pd.StringDtype(*string_dtype_arguments)),
+            pd.Series(["a", "b", None], dtype=object),
+        ]
+    )
+    assert result.dtype == np.dtype("object")
+
+
+def test_concat_with_numpy(string_dtype_arguments):
+    # common type with a numpy string dtype always preserves the pandas string dtype
+    dtype = pd.StringDtype(*string_dtype_arguments)
+    assert find_common_type([dtype, np.dtype("U")]) == dtype
+    assert find_common_type([np.dtype("U"), dtype]) == dtype
+    assert find_common_type([dtype, np.dtype("U10")]) == dtype
+    assert find_common_type([np.dtype("U10"), dtype]) == dtype
+
+    # with any other numpy dtype -> object
+    assert find_common_type([dtype, np.dtype("S")]) == np.dtype("object")
+    assert find_common_type([dtype, np.dtype("int64")]) == np.dtype("object")
+
+    if Version(np.__version__) >= Version("2"):
+        assert find_common_type([dtype, np.dtypes.StringDType()]) == dtype
+        assert find_common_type([np.dtypes.StringDType(), dtype]) == dtype
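The dtype.kind in ("U", "T") branch of _get_common_dtype is what the last test exercises: NumPy's fixed-width unicode dtypes (kind "U") and, on NumPy 2.0+, the variable-width np.dtypes.StringDType (kind "T") are absorbed into the pandas string dtype rather than forcing a fallback to object. A short sketch of that interaction, assuming NumPy >= 2 and a pandas build that includes this commit:

import numpy as np
import pandas as pd
from pandas.core.dtypes.cast import find_common_type

dtype = pd.StringDtype("python", na_value=pd.NA)

# fixed-width unicode dtypes are absorbed into the pandas string dtype
assert find_common_type([dtype, np.dtype("U10")]) == dtype

# NumPy 2.0's variable-width string dtype has kind "T" and is treated the same way
assert np.dtypes.StringDType().kind == "T"
assert find_common_type([np.dtypes.StringDType(), dtype]) == dtype

# any other NumPy dtype (bytes, integers, ...) still resolves to object
assert find_common_type([dtype, np.dtype("int64")]) == np.dtype("object")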
