diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index c878fd2664dc4..99faad8aff986 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -576,6 +576,7 @@ Strings ^^^^^^^ - Bug in :func:`pandas.api.types.is_string_dtype` while checking object array with no elements is of the string dtype (:issue:`54661`) - Bug in :meth:`DataFrame.apply` failing when ``engine="numba"`` and columns or index have ``StringDtype`` (:issue:`56189`) +- Bug in :meth:`Index.str.cat` always casting result to object dtype (:issue:`56157`) - Bug in :meth:`Series.__mul__` for :class:`ArrowDtype` with ``pyarrow.string`` dtype and ``string[pyarrow]`` for the pyarrow backend (:issue:`51970`) - Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`) diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 9fa6e9973291d..127aee24e094f 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -44,6 +44,7 @@ ) from pandas.core.dtypes.missing import isna +from pandas.core.arrays import ExtensionArray from pandas.core.base import NoNewAttributesMixin from pandas.core.construction import extract_array @@ -456,7 +457,7 @@ def _get_series_list(self, others): # in case of list-like `others`, all elements must be # either Series/Index/np.ndarray (1-dim)... if all( - isinstance(x, (ABCSeries, ABCIndex)) + isinstance(x, (ABCSeries, ABCIndex, ExtensionArray)) or (isinstance(x, np.ndarray) and x.ndim == 1) for x in others ): @@ -690,12 +691,15 @@ def cat( out: Index | Series if isinstance(self._orig, ABCIndex): # add dtype for case that result is all-NA + dtype = None + if isna(result).all(): + dtype = object - out = Index(result, dtype=object, name=self._orig.name) + out = Index(result, dtype=dtype, name=self._orig.name) else: # Series if isinstance(self._orig.dtype, CategoricalDtype): # We need to infer the new categories. - dtype = None + dtype = self._orig.dtype.categories.dtype # type: ignore[assignment] else: dtype = self._orig.dtype res_ser = Series( diff --git a/pandas/tests/strings/test_api.py b/pandas/tests/strings/test_api.py index 2914b22a52e94..fd2501835318d 100644 --- a/pandas/tests/strings/test_api.py +++ b/pandas/tests/strings/test_api.py @@ -2,6 +2,7 @@ import pytest from pandas import ( + CategoricalDtype, DataFrame, Index, MultiIndex, @@ -178,6 +179,7 @@ def test_api_for_categorical(any_string_method, any_string_dtype): s = Series(list("aabb"), dtype=any_string_dtype) s = s + " " + s c = s.astype("category") + c = c.astype(CategoricalDtype(c.dtype.categories.astype("object"))) assert isinstance(c.str, StringMethods) method_name, args, kwargs = any_string_method diff --git a/pandas/tests/strings/test_cat.py b/pandas/tests/strings/test_cat.py index 3e620b7664335..284932491a65e 100644 --- a/pandas/tests/strings/test_cat.py +++ b/pandas/tests/strings/test_cat.py @@ -3,6 +3,8 @@ import numpy as np import pytest +import pandas.util._test_decorators as td + from pandas import ( DataFrame, Index, @@ -10,6 +12,7 @@ Series, _testing as tm, concat, + option_context, ) @@ -26,45 +29,49 @@ def test_str_cat_name(index_or_series, other): assert result.name == "name" -def test_str_cat(index_or_series): - box = index_or_series - # test_cat above tests "str_cat" from ndarray; - # here testing "str.cat" from Series/Index to ndarray/list - s = box(["a", "a", "b", "b", "c", np.nan]) +@pytest.mark.parametrize( + "infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))] +) +def test_str_cat(index_or_series, infer_string): + with option_context("future.infer_string", infer_string): + box = index_or_series + # test_cat above tests "str_cat" from ndarray; + # here testing "str.cat" from Series/Index to ndarray/list + s = box(["a", "a", "b", "b", "c", np.nan]) - # single array - result = s.str.cat() - expected = "aabbc" - assert result == expected + # single array + result = s.str.cat() + expected = "aabbc" + assert result == expected - result = s.str.cat(na_rep="-") - expected = "aabbc-" - assert result == expected + result = s.str.cat(na_rep="-") + expected = "aabbc-" + assert result == expected - result = s.str.cat(sep="_", na_rep="NA") - expected = "a_a_b_b_c_NA" - assert result == expected + result = s.str.cat(sep="_", na_rep="NA") + expected = "a_a_b_b_c_NA" + assert result == expected - t = np.array(["a", np.nan, "b", "d", "foo", np.nan], dtype=object) - expected = box(["aa", "a-", "bb", "bd", "cfoo", "--"]) + t = np.array(["a", np.nan, "b", "d", "foo", np.nan], dtype=object) + expected = box(["aa", "a-", "bb", "bd", "cfoo", "--"]) - # Series/Index with array - result = s.str.cat(t, na_rep="-") - tm.assert_equal(result, expected) + # Series/Index with array + result = s.str.cat(t, na_rep="-") + tm.assert_equal(result, expected) - # Series/Index with list - result = s.str.cat(list(t), na_rep="-") - tm.assert_equal(result, expected) + # Series/Index with list + result = s.str.cat(list(t), na_rep="-") + tm.assert_equal(result, expected) - # errors for incorrect lengths - rgx = r"If `others` contains arrays or lists \(or other list-likes.*" - z = Series(["1", "2", "3"]) + # errors for incorrect lengths + rgx = r"If `others` contains arrays or lists \(or other list-likes.*" + z = Series(["1", "2", "3"]) - with pytest.raises(ValueError, match=rgx): - s.str.cat(z.values) + with pytest.raises(ValueError, match=rgx): + s.str.cat(z.values) - with pytest.raises(ValueError, match=rgx): - s.str.cat(list(z)) + with pytest.raises(ValueError, match=rgx): + s.str.cat(list(z)) def test_str_cat_raises_intuitive_error(index_or_series): @@ -78,39 +85,54 @@ def test_str_cat_raises_intuitive_error(index_or_series): s.str.cat(" ") +@pytest.mark.parametrize( + "infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))] +) @pytest.mark.parametrize("sep", ["", None]) @pytest.mark.parametrize("dtype_target", ["object", "category"]) @pytest.mark.parametrize("dtype_caller", ["object", "category"]) -def test_str_cat_categorical(index_or_series, dtype_caller, dtype_target, sep): +def test_str_cat_categorical( + index_or_series, dtype_caller, dtype_target, sep, infer_string +): box = index_or_series - s = Index(["a", "a", "b", "a"], dtype=dtype_caller) - s = s if box == Index else Series(s, index=s) - t = Index(["b", "a", "b", "c"], dtype=dtype_target) - - expected = Index(["ab", "aa", "bb", "ac"]) - expected = expected if box == Index else Series(expected, index=s) + with option_context("future.infer_string", infer_string): + s = Index(["a", "a", "b", "a"], dtype=dtype_caller) + s = s if box == Index else Series(s, index=s) + t = Index(["b", "a", "b", "c"], dtype=dtype_target) - # Series/Index with unaligned Index -> t.values - result = s.str.cat(t.values, sep=sep) - tm.assert_equal(result, expected) - - # Series/Index with Series having matching Index - t = Series(t.values, index=s) - result = s.str.cat(t, sep=sep) - tm.assert_equal(result, expected) - - # Series/Index with Series.values - result = s.str.cat(t.values, sep=sep) - tm.assert_equal(result, expected) + expected = Index(["ab", "aa", "bb", "ac"]) + expected = ( + expected + if box == Index + else Series(expected, index=Index(s, dtype=dtype_caller)) + ) - # Series/Index with Series having different Index - t = Series(t.values, index=t.values) - expected = Index(["aa", "aa", "bb", "bb", "aa"]) - expected = expected if box == Index else Series(expected, index=expected.str[:1]) + # Series/Index with unaligned Index -> t.values + result = s.str.cat(t.values, sep=sep) + tm.assert_equal(result, expected) + + # Series/Index with Series having matching Index + t = Series(t.values, index=Index(s, dtype=dtype_caller)) + result = s.str.cat(t, sep=sep) + tm.assert_equal(result, expected) + + # Series/Index with Series.values + result = s.str.cat(t.values, sep=sep) + tm.assert_equal(result, expected) + + # Series/Index with Series having different Index + t = Series(t.values, index=t.values) + expected = Index(["aa", "aa", "bb", "bb", "aa"]) + dtype = object if dtype_caller == "object" else s.dtype.categories.dtype + expected = ( + expected + if box == Index + else Series(expected, index=Index(expected.str[:1], dtype=dtype)) + ) - result = s.str.cat(t, sep=sep) - tm.assert_equal(result, expected) + result = s.str.cat(t, sep=sep) + tm.assert_equal(result, expected) @pytest.mark.parametrize( @@ -321,8 +343,9 @@ def test_str_cat_all_na(index_or_series, index_or_series2): # all-NA target if box == Series: - expected = Series([np.nan] * 4, index=s.index, dtype=object) + expected = Series([np.nan] * 4, index=s.index, dtype=s.dtype) else: # box == Index + # TODO: Strimg option, this should return string dtype expected = Index([np.nan] * 4, dtype=object) result = s.str.cat(t, join="left") tm.assert_equal(result, expected)