Skip to content

Commit a3626f2

Browse files
authored
BUG: Index.str.cat casting result always to object (#56157)
* BUG: Index.str.cat casting result always to object * Update accessor.py * Fix further bugs * Fix * Update accessor.py * Update v2.1.4.rst * Update v2.2.0.rst
1 parent 45361a4 commit a3626f2

File tree

4 files changed

+89
-59
lines changed

4 files changed

+89
-59
lines changed

doc/source/whatsnew/v2.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ Strings
576576
^^^^^^^
577577
- Bug in :func:`pandas.api.types.is_string_dtype` while checking object array with no elements is of the string dtype (:issue:`54661`)
578578
- Bug in :meth:`DataFrame.apply` failing when ``engine="numba"`` and columns or index have ``StringDtype`` (:issue:`56189`)
579+
- Bug in :meth:`Index.str.cat` always casting result to object dtype (:issue:`56157`)
579580
- Bug in :meth:`Series.__mul__` for :class:`ArrowDtype` with ``pyarrow.string`` dtype and ``string[pyarrow]`` for the pyarrow backend (:issue:`51970`)
580581
- Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56404`)
581582
- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`)

pandas/core/strings/accessor.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
)
4545
from pandas.core.dtypes.missing import isna
4646

47+
from pandas.core.arrays import ExtensionArray
4748
from pandas.core.base import NoNewAttributesMixin
4849
from pandas.core.construction import extract_array
4950

@@ -456,7 +457,7 @@ def _get_series_list(self, others):
456457
# in case of list-like `others`, all elements must be
457458
# either Series/Index/np.ndarray (1-dim)...
458459
if all(
459-
isinstance(x, (ABCSeries, ABCIndex))
460+
isinstance(x, (ABCSeries, ABCIndex, ExtensionArray))
460461
or (isinstance(x, np.ndarray) and x.ndim == 1)
461462
for x in others
462463
):
@@ -690,12 +691,15 @@ def cat(
690691
out: Index | Series
691692
if isinstance(self._orig, ABCIndex):
692693
# add dtype for case that result is all-NA
694+
dtype = None
695+
if isna(result).all():
696+
dtype = object
693697

694-
out = Index(result, dtype=object, name=self._orig.name)
698+
out = Index(result, dtype=dtype, name=self._orig.name)
695699
else: # Series
696700
if isinstance(self._orig.dtype, CategoricalDtype):
697701
# We need to infer the new categories.
698-
dtype = None
702+
dtype = self._orig.dtype.categories.dtype # type: ignore[assignment]
699703
else:
700704
dtype = self._orig.dtype
701705
res_ser = Series(

pandas/tests/strings/test_api.py

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
from pandas import (
5+
CategoricalDtype,
56
DataFrame,
67
Index,
78
MultiIndex,
@@ -178,6 +179,7 @@ def test_api_for_categorical(any_string_method, any_string_dtype):
178179
s = Series(list("aabb"), dtype=any_string_dtype)
179180
s = s + " " + s
180181
c = s.astype("category")
182+
c = c.astype(CategoricalDtype(c.dtype.categories.astype("object")))
181183
assert isinstance(c.str, StringMethods)
182184

183185
method_name, args, kwargs = any_string_method

pandas/tests/strings/test_cat.py

+79-56
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
import numpy as np
44
import pytest
55

6+
import pandas.util._test_decorators as td
7+
68
from pandas import (
79
DataFrame,
810
Index,
911
MultiIndex,
1012
Series,
1113
_testing as tm,
1214
concat,
15+
option_context,
1316
)
1417

1518

@@ -26,45 +29,49 @@ def test_str_cat_name(index_or_series, other):
2629
assert result.name == "name"
2730

2831

29-
def test_str_cat(index_or_series):
30-
box = index_or_series
31-
# test_cat above tests "str_cat" from ndarray;
32-
# here testing "str.cat" from Series/Index to ndarray/list
33-
s = box(["a", "a", "b", "b", "c", np.nan])
32+
@pytest.mark.parametrize(
33+
"infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))]
34+
)
35+
def test_str_cat(index_or_series, infer_string):
36+
with option_context("future.infer_string", infer_string):
37+
box = index_or_series
38+
# test_cat above tests "str_cat" from ndarray;
39+
# here testing "str.cat" from Series/Index to ndarray/list
40+
s = box(["a", "a", "b", "b", "c", np.nan])
3441

35-
# single array
36-
result = s.str.cat()
37-
expected = "aabbc"
38-
assert result == expected
42+
# single array
43+
result = s.str.cat()
44+
expected = "aabbc"
45+
assert result == expected
3946

40-
result = s.str.cat(na_rep="-")
41-
expected = "aabbc-"
42-
assert result == expected
47+
result = s.str.cat(na_rep="-")
48+
expected = "aabbc-"
49+
assert result == expected
4350

44-
result = s.str.cat(sep="_", na_rep="NA")
45-
expected = "a_a_b_b_c_NA"
46-
assert result == expected
51+
result = s.str.cat(sep="_", na_rep="NA")
52+
expected = "a_a_b_b_c_NA"
53+
assert result == expected
4754

48-
t = np.array(["a", np.nan, "b", "d", "foo", np.nan], dtype=object)
49-
expected = box(["aa", "a-", "bb", "bd", "cfoo", "--"])
55+
t = np.array(["a", np.nan, "b", "d", "foo", np.nan], dtype=object)
56+
expected = box(["aa", "a-", "bb", "bd", "cfoo", "--"])
5057

51-
# Series/Index with array
52-
result = s.str.cat(t, na_rep="-")
53-
tm.assert_equal(result, expected)
58+
# Series/Index with array
59+
result = s.str.cat(t, na_rep="-")
60+
tm.assert_equal(result, expected)
5461

55-
# Series/Index with list
56-
result = s.str.cat(list(t), na_rep="-")
57-
tm.assert_equal(result, expected)
62+
# Series/Index with list
63+
result = s.str.cat(list(t), na_rep="-")
64+
tm.assert_equal(result, expected)
5865

59-
# errors for incorrect lengths
60-
rgx = r"If `others` contains arrays or lists \(or other list-likes.*"
61-
z = Series(["1", "2", "3"])
66+
# errors for incorrect lengths
67+
rgx = r"If `others` contains arrays or lists \(or other list-likes.*"
68+
z = Series(["1", "2", "3"])
6269

63-
with pytest.raises(ValueError, match=rgx):
64-
s.str.cat(z.values)
70+
with pytest.raises(ValueError, match=rgx):
71+
s.str.cat(z.values)
6572

66-
with pytest.raises(ValueError, match=rgx):
67-
s.str.cat(list(z))
73+
with pytest.raises(ValueError, match=rgx):
74+
s.str.cat(list(z))
6875

6976

7077
def test_str_cat_raises_intuitive_error(index_or_series):
@@ -78,39 +85,54 @@ def test_str_cat_raises_intuitive_error(index_or_series):
7885
s.str.cat(" ")
7986

8087

88+
@pytest.mark.parametrize(
89+
"infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))]
90+
)
8191
@pytest.mark.parametrize("sep", ["", None])
8292
@pytest.mark.parametrize("dtype_target", ["object", "category"])
8393
@pytest.mark.parametrize("dtype_caller", ["object", "category"])
84-
def test_str_cat_categorical(index_or_series, dtype_caller, dtype_target, sep):
94+
def test_str_cat_categorical(
95+
index_or_series, dtype_caller, dtype_target, sep, infer_string
96+
):
8597
box = index_or_series
8698

87-
s = Index(["a", "a", "b", "a"], dtype=dtype_caller)
88-
s = s if box == Index else Series(s, index=s)
89-
t = Index(["b", "a", "b", "c"], dtype=dtype_target)
90-
91-
expected = Index(["ab", "aa", "bb", "ac"])
92-
expected = expected if box == Index else Series(expected, index=s)
99+
with option_context("future.infer_string", infer_string):
100+
s = Index(["a", "a", "b", "a"], dtype=dtype_caller)
101+
s = s if box == Index else Series(s, index=s)
102+
t = Index(["b", "a", "b", "c"], dtype=dtype_target)
93103

94-
# Series/Index with unaligned Index -> t.values
95-
result = s.str.cat(t.values, sep=sep)
96-
tm.assert_equal(result, expected)
97-
98-
# Series/Index with Series having matching Index
99-
t = Series(t.values, index=s)
100-
result = s.str.cat(t, sep=sep)
101-
tm.assert_equal(result, expected)
102-
103-
# Series/Index with Series.values
104-
result = s.str.cat(t.values, sep=sep)
105-
tm.assert_equal(result, expected)
104+
expected = Index(["ab", "aa", "bb", "ac"])
105+
expected = (
106+
expected
107+
if box == Index
108+
else Series(expected, index=Index(s, dtype=dtype_caller))
109+
)
106110

107-
# Series/Index with Series having different Index
108-
t = Series(t.values, index=t.values)
109-
expected = Index(["aa", "aa", "bb", "bb", "aa"])
110-
expected = expected if box == Index else Series(expected, index=expected.str[:1])
111+
# Series/Index with unaligned Index -> t.values
112+
result = s.str.cat(t.values, sep=sep)
113+
tm.assert_equal(result, expected)
114+
115+
# Series/Index with Series having matching Index
116+
t = Series(t.values, index=Index(s, dtype=dtype_caller))
117+
result = s.str.cat(t, sep=sep)
118+
tm.assert_equal(result, expected)
119+
120+
# Series/Index with Series.values
121+
result = s.str.cat(t.values, sep=sep)
122+
tm.assert_equal(result, expected)
123+
124+
# Series/Index with Series having different Index
125+
t = Series(t.values, index=t.values)
126+
expected = Index(["aa", "aa", "bb", "bb", "aa"])
127+
dtype = object if dtype_caller == "object" else s.dtype.categories.dtype
128+
expected = (
129+
expected
130+
if box == Index
131+
else Series(expected, index=Index(expected.str[:1], dtype=dtype))
132+
)
111133

112-
result = s.str.cat(t, sep=sep)
113-
tm.assert_equal(result, expected)
134+
result = s.str.cat(t, sep=sep)
135+
tm.assert_equal(result, expected)
114136

115137

116138
@pytest.mark.parametrize(
@@ -321,8 +343,9 @@ def test_str_cat_all_na(index_or_series, index_or_series2):
321343

322344
# all-NA target
323345
if box == Series:
324-
expected = Series([np.nan] * 4, index=s.index, dtype=object)
346+
expected = Series([np.nan] * 4, index=s.index, dtype=s.dtype)
325347
else: # box == Index
348+
# TODO: String option, this should return string dtype
326349
expected = Index([np.nan] * 4, dtype=object)
327350
result = s.str.cat(t, join="left")
328351
tm.assert_equal(result, expected)

0 commit comments

Comments
 (0)