Skip to content

Commit 54b47df

Browse files
[backport 2.3.x] BUG (string dtype): replace with non-string to fall back to object dtype (#60285) (#60292)
* BUG (string dtype): replace with non-string to fall back to object dtype (#60285)

  Co-authored-by: Matthew Roeschke <[email protected]>

  (cherry picked from commit 938832b)

* updates for 2.3
* fix inplace modification for 2.3.x branch with python storage
1 parent 2054463 commit 54b47df

File tree

7 files changed

+76
-59
lines changed

7 files changed

+76
-59
lines changed

doc/source/whatsnew/v2.3.0.rst

+1-1
Original file line number | Diff line number | Diff line change
@@ -107,10 +107,10 @@ Conversion
107107
Strings
108108
^^^^^^^
109109
- Bug in :meth:`Series.rank` for :class:`StringDtype` with ``storage="pyarrow"`` incorrectly returning integer results in case of ``method="average"`` and raising an error if it would truncate results (:issue:`59768`)
110+
- Bug in :meth:`Series.replace` with :class:`StringDtype` where replacing with a non-string value did not upcast to ``object`` dtype (:issue:`60282`)
110111
- Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`StringDtype` with ``storage="pyarrow"`` (:issue:`59628`)
111112
- Bug in ``ser.str.slice`` with negative ``step`` with :class:`ArrowDtype` and :class:`StringDtype` with ``storage="pyarrow"`` giving incorrect results (:issue:`59710`)
112113
- Bug in the ``center`` method on :class:`Series` and :class:`Index` object ``str`` accessors with pyarrow-backed dtype not matching the python behavior in corner cases with an odd number of fill characters (:issue:`54792`)
113-
-
114114

115115
Interval
116116
^^^^^^^^

pandas/core/arrays/string_.py

+25-18
Original file line number | Diff line number | Diff line change
@@ -726,20 +726,9 @@ def _values_for_factorize(self) -> tuple[np.ndarray, libmissing.NAType | float]:
726726

727727
return arr, self.dtype.na_value
728728

729-
def __setitem__(self, key, value) -> None:
730-
value = extract_array(value, extract_numpy=True)
731-
if isinstance(value, type(self)):
732-
# extract_array doesn't extract NumpyExtensionArray subclasses
733-
value = value._ndarray
734-
735-
key = check_array_indexer(self, key)
736-
scalar_key = lib.is_scalar(key)
737-
scalar_value = lib.is_scalar(value)
738-
if scalar_key and not scalar_value:
739-
raise ValueError("setting an array element with a sequence.")
740-
741-
# validate new items
742-
if scalar_value:
729+
def _maybe_convert_setitem_value(self, value):
730+
"""Maybe convert value to be pyarrow compatible."""
731+
if lib.is_scalar(value):
743732
if isna(value):
744733
value = self.dtype.na_value
745734
elif not isinstance(value, str):
@@ -749,8 +738,11 @@ def __setitem__(self, key, value) -> None:
749738
"instead."
750739
)
751740
else:
741+
value = extract_array(value, extract_numpy=True)
752742
if not is_array_like(value):
753743
value = np.asarray(value, dtype=object)
744+
elif isinstance(value.dtype, type(self.dtype)):
745+
return value
754746
else:
755747
# cast categories and friends to arrays to see if values are
756748
# compatible, compatibility with arrow backed strings
@@ -760,11 +752,26 @@ def __setitem__(self, key, value) -> None:
760752
"Invalid value for dtype 'str'. Value should be a "
761753
"string or missing value (or array of those)."
762754
)
755+
return value
763756

764-
mask = isna(value)
765-
if mask.any():
766-
value = value.copy()
767-
value[isna(value)] = self.dtype.na_value
757+
def __setitem__(self, key, value) -> None:
758+
value = self._maybe_convert_setitem_value(value)
759+
760+
key = check_array_indexer(self, key)
761+
scalar_key = lib.is_scalar(key)
762+
scalar_value = lib.is_scalar(value)
763+
if scalar_key and not scalar_value:
764+
raise ValueError("setting an array element with a sequence.")
765+
766+
if not scalar_value:
767+
if value.dtype == self.dtype:
768+
value = value._ndarray
769+
else:
770+
value = np.asarray(value)
771+
mask = isna(value)
772+
if mask.any():
773+
value = value.copy()
774+
value[isna(value)] = self.dtype.na_value
768775

769776
super().__setitem__(key, value)
770777

pandas/core/dtypes/cast.py

+7
Original file line number | Diff line number | Diff line change
@@ -1754,6 +1754,13 @@ def can_hold_element(arr: ArrayLike, element: Any) -> bool:
17541754
except (ValueError, TypeError):
17551755
return False
17561756

1757+
if dtype == "string":
1758+
try:
1759+
arr._maybe_convert_setitem_value(element) # type: ignore[union-attr]
1760+
return True
1761+
except (ValueError, TypeError):
1762+
return False
1763+
17571764
# This is technically incorrect, but maintains the behavior of
17581765
# ExtensionBlock._can_hold_element
17591766
return True

pandas/core/internals/blocks.py

+32-8
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,7 @@
8484
ABCNumpyExtensionArray,
8585
ABCSeries,
8686
)
87+
from pandas.core.dtypes.inference import is_re
8788
from pandas.core.dtypes.missing import (
8889
is_valid_na_for_dtype,
8990
isna,
@@ -115,6 +116,7 @@
115116
PeriodArray,
116117
TimedeltaArray,
117118
)
119+
from pandas.core.arrays.string_ import StringDtype
118120
from pandas.core.base import PandasObject
119121
import pandas.core.common as com
120122
from pandas.core.computation import expressions
@@ -476,7 +478,9 @@ def split_and_operate(self, func, *args, **kwargs) -> list[Block]:
476478
# Up/Down-casting
477479

478480
@final
479-
def coerce_to_target_dtype(self, other, warn_on_upcast: bool = False) -> Block:
481+
def coerce_to_target_dtype(
482+
self, other, warn_on_upcast: bool = False, using_cow: bool = False
483+
) -> Block:
480484
"""
481485
coerce the current block to a dtype compat for other
482486
we will return a block, possibly object, and not raise
@@ -528,7 +532,14 @@ def coerce_to_target_dtype(self, other, warn_on_upcast: bool = False) -> Block:
528532
f"{self.values.dtype}. Please report a bug at "
529533
"https://github.com/pandas-dev/pandas/issues."
530534
)
531-
return self.astype(new_dtype, copy=False)
535+
copy = False
536+
if (
537+
not using_cow
538+
and isinstance(self.dtype, StringDtype)
539+
and self.dtype.storage == "python"
540+
):
541+
copy = True
542+
return self.astype(new_dtype, copy=copy, using_cow=using_cow)
532543

533544
@final
534545
def _maybe_downcast(
@@ -879,7 +890,7 @@ def replace(
879890
else:
880891
return [self] if inplace else [self.copy()]
881892

882-
elif self._can_hold_element(value):
893+
elif self._can_hold_element(value) or (self.dtype == "string" and is_re(value)):
883894
# TODO(CoW): Maybe split here as well into columns where mask has True
884895
# and rest?
885896
blk = self._maybe_copy(using_cow, inplace)
@@ -926,12 +937,13 @@ def replace(
926937
if value is None or value is NA:
927938
blk = self.astype(np.dtype(object))
928939
else:
929-
blk = self.coerce_to_target_dtype(value)
940+
blk = self.coerce_to_target_dtype(value, using_cow=using_cow)
930941
return blk.replace(
931942
to_replace=to_replace,
932943
value=value,
933944
inplace=True,
934945
mask=mask,
946+
using_cow=using_cow,
935947
)
936948

937949
else:
@@ -980,16 +992,26 @@ def _replace_regex(
980992
-------
981993
List[Block]
982994
"""
983-
if not self._can_hold_element(to_replace):
995+
if not is_re(to_replace) and not self._can_hold_element(to_replace):
984996
# i.e. only if self.is_object is True, but could in principle include a
985997
# String ExtensionBlock
986998
if using_cow:
987999
return [self.copy(deep=False)]
9881000
return [self] if inplace else [self.copy()]
9891001

990-
rx = re.compile(to_replace)
1002+
if is_re(to_replace) and self.dtype not in [object, "string"]:
1003+
# only object or string dtype can hold strings, and a regex object
1004+
# will only match strings
1005+
return [self.copy(deep=False)]
9911006

992-
block = self._maybe_copy(using_cow, inplace)
1007+
if not (
1008+
self._can_hold_element(value) or (self.dtype == "string" and is_re(value))
1009+
):
1010+
block = self.astype(np.dtype(object))
1011+
else:
1012+
block = self._maybe_copy(using_cow, inplace)
1013+
1014+
rx = re.compile(to_replace)
9931015

9941016
replace_regex(block.values, rx, value, mask)
9951017

@@ -1048,7 +1070,9 @@ def replace_list(
10481070

10491071
# Exclude anything that we know we won't contain
10501072
pairs = [
1051-
(x, y) for x, y in zip(src_list, dest_list) if self._can_hold_element(x)
1073+
(x, y)
1074+
for x, y in zip(src_list, dest_list)
1075+
if (self._can_hold_element(x) or (self.dtype == "string" and is_re(x)))
10521076
]
10531077
if not len(pairs):
10541078
if using_cow:

pandas/tests/frame/methods/test_replace.py

-9
Original file line number | Diff line number | Diff line change
@@ -297,7 +297,6 @@ def test_regex_replace_dict_nested_non_first_character(
297297
expected = DataFrame({"first": [".bc", "bc.", "c.b"]}, dtype=dtype)
298298
tm.assert_frame_equal(result, expected)
299299

300-
@pytest.mark.xfail(using_string_dtype(), reason="can't set float into string")
301300
def test_regex_replace_dict_nested_gh4115(self):
302301
df = DataFrame({"Type": ["Q", "T", "Q", "Q", "T"], "tmp": 2})
303302
expected = DataFrame({"Type": [0, 1, 0, 0, 1], "tmp": 2})
@@ -556,7 +555,6 @@ def test_replace_series_dict(self):
556555
result = df.replace(s, df.mean())
557556
tm.assert_frame_equal(result, expected)
558557

559-
@pytest.mark.xfail(using_string_dtype(), reason="can't set float into string")
560558
def test_replace_convert(self):
561559
# gh 3907
562560
df = DataFrame([["foo", "bar", "bah"], ["bar", "foo", "bah"]])
@@ -932,7 +930,6 @@ def test_replace_input_formats_listlike(self):
932930
with pytest.raises(ValueError, match=msg):
933931
df.replace(to_rep, values[1:])
934932

935-
@pytest.mark.xfail(using_string_dtype(), reason="can't set float into string")
936933
def test_replace_input_formats_scalar(self):
937934
df = DataFrame(
938935
{"A": [np.nan, 0, np.inf], "B": [0, 2, 5], "C": ["", "asdf", "fd"]}
@@ -961,7 +958,6 @@ def test_replace_limit(self):
961958
# TODO
962959
pass
963960

964-
@pytest.mark.xfail(using_string_dtype(), reason="can't set float into string")
965961
def test_replace_dict_no_regex(self):
966962
answer = Series(
967963
{
@@ -985,7 +981,6 @@ def test_replace_dict_no_regex(self):
985981
result = answer.replace(weights)
986982
tm.assert_series_equal(result, expected)
987983

988-
@pytest.mark.xfail(using_string_dtype(), reason="can't set float into string")
989984
def test_replace_series_no_regex(self):
990985
answer = Series(
991986
{
@@ -1104,7 +1099,6 @@ def test_replace_swapping_bug(self, using_infer_string):
11041099
expect = DataFrame({"a": ["Y", "N", "Y"]})
11051100
tm.assert_frame_equal(res, expect)
11061101

1107-
@pytest.mark.xfail(using_string_dtype(), reason="can't set float into string")
11081102
def test_replace_period(self):
11091103
d = {
11101104
"fname": {
@@ -1141,7 +1135,6 @@ def test_replace_period(self):
11411135
result = df.replace(d)
11421136
tm.assert_frame_equal(result, expected)
11431137

1144-
@pytest.mark.xfail(using_string_dtype(), reason="can't set float into string")
11451138
def test_replace_datetime(self):
11461139
d = {
11471140
"fname": {
@@ -1367,7 +1360,6 @@ def test_replace_commutative(self, df, to_replace, exp):
13671360
result = df.replace(to_replace)
13681361
tm.assert_frame_equal(result, expected)
13691362

1370-
@pytest.mark.xfail(using_string_dtype(), reason="can't set float into string")
13711363
@pytest.mark.parametrize(
13721364
"replacer",
13731365
[
@@ -1644,7 +1636,6 @@ def test_regex_replace_scalar(
16441636
expected.loc[expected["a"] == ".", "a"] = expected_replace_val
16451637
tm.assert_frame_equal(result, expected)
16461638

1647-
@pytest.mark.xfail(using_string_dtype(), reason="can't set float into string")
16481639
@pytest.mark.parametrize("regex", [False, True])
16491640
def test_replace_regex_dtype_frame(self, regex):
16501641
# GH-48644

pandas/tests/series/indexing/test_setitem.py

+5-13
Original file line number | Diff line number | Diff line change
@@ -886,24 +886,16 @@ def test_index_where(self, obj, key, expected, warn, val):
886886
mask = np.zeros(obj.shape, dtype=bool)
887887
mask[key] = True
888888

889-
if obj.dtype == "string" and not (isinstance(val, str) or isna(val)):
890-
with pytest.raises(TypeError, match="Invalid value"):
891-
Index(obj, dtype=obj.dtype).where(~mask, val)
892-
else:
893-
res = Index(obj, dtype=obj.dtype).where(~mask, val)
894-
expected_idx = Index(expected, dtype=expected.dtype)
895-
tm.assert_index_equal(res, expected_idx)
889+
res = Index(obj, dtype=obj.dtype).where(~mask, val)
890+
expected_idx = Index(expected, dtype=expected.dtype)
891+
tm.assert_index_equal(res, expected_idx)
896892

897893
def test_index_putmask(self, obj, key, expected, warn, val):
898894
mask = np.zeros(obj.shape, dtype=bool)
899895
mask[key] = True
900896

901-
if obj.dtype == "string" and not (isinstance(val, str) or isna(val)):
902-
with pytest.raises(TypeError, match="Invalid value"):
903-
Index(obj, dtype=obj.dtype).putmask(mask, val)
904-
else:
905-
res = Index(obj, dtype=obj.dtype).putmask(mask, val)
906-
tm.assert_index_equal(res, Index(expected, dtype=expected.dtype))
897+
res = Index(obj, dtype=obj.dtype).putmask(mask, val)
898+
tm.assert_index_equal(res, Index(expected, dtype=expected.dtype))
907899

908900

909901
@pytest.mark.parametrize(

pandas/tests/series/methods/test_replace.py

+6-10
Original file line number | Diff line number | Diff line change
@@ -403,10 +403,6 @@ def test_replace_categorical(self, categorical, numeric, using_infer_string):
403403
ser = pd.Series(categorical)
404404
msg = "Downcasting behavior in `replace`"
405405
msg = "with CategoricalDtype is deprecated"
406-
if using_infer_string:
407-
with pytest.raises(TypeError, match="Invalid value"):
408-
ser.replace({"A": 1, "B": 2})
409-
return
410406
with tm.assert_produces_warning(FutureWarning, match=msg):
411407
result = ser.replace({"A": 1, "B": 2})
412408
expected = pd.Series(numeric).astype("category")
@@ -745,13 +741,13 @@ def test_replace_regex_dtype_series(self, regex):
745741
tm.assert_series_equal(result, expected)
746742

747743
@pytest.mark.parametrize("regex", [False, True])
748-
def test_replace_regex_dtype_series_string(self, regex, using_infer_string):
749-
if not using_infer_string:
750-
# then this is object dtype which is already tested above
751-
return
744+
def test_replace_regex_dtype_series_string(self, regex):
752745
series = pd.Series(["0"], dtype="str")
753-
with pytest.raises(TypeError, match="Invalid value"):
754-
series.replace(to_replace="0", value=1, regex=regex)
746+
expected = pd.Series([1], dtype="int64")
747+
msg = "Downcasting behavior in `replace`"
748+
with tm.assert_produces_warning(FutureWarning, match=msg):
749+
result = series.replace(to_replace="0", value=1, regex=regex)
750+
tm.assert_series_equal(result, expected)
755751

756752
def test_replace_different_int_types(self, any_int_numpy_dtype):
757753
# GH#45311

0 commit comments

Comments
 (0)