Skip to content

Commit 9cd1c6f

Browse files
authored
BUG: nullable dtypes not preserved in Series.replace (pandas-dev#44940)
1 parent 9e1a741 commit 9cd1c6f

File tree

7 files changed

+157
-67
lines changed

7 files changed

+157
-67
lines changed

doc/source/whatsnew/v1.4.0.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ Other Deprecations
529529
- Deprecated silent dropping of columns that raised a ``TypeError`` in :class:`Series.transform` and :class:`DataFrame.transform` when used with a dictionary (:issue:`43740`)
530530
- Deprecated silent dropping of columns that raised a ``TypeError``, ``DataError``, and some cases of ``ValueError`` in :meth:`Series.aggregate`, :meth:`DataFrame.aggregate`, :meth:`Series.groupby.aggregate`, and :meth:`DataFrame.groupby.aggregate` when used with a list (:issue:`43740`)
531531
- Deprecated casting behavior when setting timezone-aware value(s) into a timezone-aware :class:`Series` or :class:`DataFrame` column when the timezones do not match. Previously this cast to object dtype. In a future version, the values being inserted will be converted to the series or column's existing timezone (:issue:`37605`)
532-
- Deprecated casting behavior when passing an item with mismatched-timezone to :meth:`DatetimeIndex.insert`, :meth:`DatetimeIndex.putmask`, :meth:`DatetimeIndex.where` :meth:`DatetimeIndex.fillna`, :meth:`Series.mask`, :meth:`Series.where`, :meth:`Series.fillna`, :meth:`Series.shift`, :meth:`Series.replace`, :meth:`Series.reindex` (and :class:`DataFrame` column analogues). In the past this has cast to object dtype. In a future version, these will cast the passed item to the index or series's timezone (:issue:`37605`)
532+
- Deprecated casting behavior when passing an item with mismatched-timezone to :meth:`DatetimeIndex.insert`, :meth:`DatetimeIndex.putmask`, :meth:`DatetimeIndex.where` :meth:`DatetimeIndex.fillna`, :meth:`Series.mask`, :meth:`Series.where`, :meth:`Series.fillna`, :meth:`Series.shift`, :meth:`Series.replace`, :meth:`Series.reindex` (and :class:`DataFrame` column analogues). In the past this has cast to object dtype. In a future version, these will cast the passed item to the index or series's timezone (:issue:`37605`,:issue:`44940`)
533533
- Deprecated the 'errors' keyword argument in :meth:`Series.where`, :meth:`DataFrame.where`, :meth:`Series.mask`, and meth:`DataFrame.mask`; in a future version the argument will be removed (:issue:`44294`)
534534
- Deprecated the ``prefix`` keyword argument in :func:`read_csv` and :func:`read_table`, in a future version the argument will be removed (:issue:`43396`)
535535
- Deprecated :meth:`PeriodIndex.astype` to ``datetime64[ns]`` or ``DatetimeTZDtype``, use ``obj.to_timestamp(how).tz_localize(dtype.tz)`` instead (:issue:`44398`)
@@ -843,7 +843,7 @@ ExtensionArray
843843
- Bug in :func:`array` incorrectly raising when passed a ``ndarray`` with ``float16`` dtype (:issue:`44715`)
844844
- Bug in calling ``np.sqrt`` on :class:`BooleanArray` returning a malformed :class:`FloatingArray` (:issue:`44715`)
845845
- Bug in :meth:`Series.where` with ``ExtensionDtype`` when ``other`` is a NA scalar incompatible with the series dtype (e.g. ``NaT`` with a numeric dtype) incorrectly casting to a compatible NA value (:issue:`44697`)
846-
-
846+
- Fixed bug in :meth:`Series.replace` with ``FloatDtype``, ``string[python]``, or ``string[pyarrow]`` dtype not being preserved when possible (:issue:`33484`)
847847

848848
Styler
849849
^^^^^^

pandas/core/array_algos/replace.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def _check_comparison_types(
8080
f"Cannot compare types {repr(type_names[0])} and {repr(type_names[1])}"
8181
)
8282

83-
if not regex:
83+
if not regex or not should_use_regex(regex, b):
84+
# TODO: should use missing.mask_missing?
8485
op = lambda x: operator.eq(x, b)
8586
else:
8687
op = np.vectorize(

pandas/core/internals/blocks.py

+17-30
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,8 @@ def replace(
640640
to_replace,
641641
value,
642642
inplace: bool = False,
643+
# mask may be pre-computed if we're called from replace_list
644+
mask: npt.NDArray[np.bool_] | None = None,
643645
) -> list[Block]:
644646
"""
645647
replace the to_replace value with value, possible to create new
@@ -665,7 +667,8 @@ def replace(
665667
# replace_list instead of replace.
666668
return [self] if inplace else [self.copy()]
667669

668-
mask = missing.mask_missing(values, to_replace)
670+
if mask is None:
671+
mask = missing.mask_missing(values, to_replace)
669672
if not mask.any():
670673
# Note: we get here with test_replace_extension_other incorrectly
671674
# bc _can_hold_element is incorrect.
@@ -683,6 +686,7 @@ def replace(
683686
to_replace=to_replace,
684687
value=value,
685688
inplace=True,
689+
mask=mask,
686690
)
687691

688692
else:
@@ -746,16 +750,6 @@ def replace_list(
746750
"""
747751
values = self.values
748752

749-
# TODO: dont special-case Categorical
750-
if (
751-
isinstance(values, Categorical)
752-
and len(algos.unique(dest_list)) == 1
753-
and not regex
754-
):
755-
# We likely got here by tiling value inside NDFrame.replace,
756-
# so un-tile here
757-
return self.replace(src_list, dest_list[0], inplace)
758-
759753
# Exclude anything that we know we won't contain
760754
pairs = [
761755
(x, y) for x, y in zip(src_list, dest_list) if self._can_hold_element(x)
@@ -844,25 +838,18 @@ def _replace_coerce(
844838
-------
845839
List[Block]
846840
"""
847-
if mask.any():
848-
if not regex:
849-
nb = self.coerce_to_target_dtype(value)
850-
if nb is self and not inplace:
851-
nb = nb.copy()
852-
putmask_inplace(nb.values, mask, value)
853-
return [nb]
854-
else:
855-
regex = should_use_regex(regex, to_replace)
856-
if regex:
857-
return self._replace_regex(
858-
to_replace,
859-
value,
860-
inplace=inplace,
861-
convert=False,
862-
mask=mask,
863-
)
864-
return self.replace(to_replace, value, inplace=inplace)
865-
return [self]
841+
if should_use_regex(regex, to_replace):
842+
return self._replace_regex(
843+
to_replace,
844+
value,
845+
inplace=inplace,
846+
convert=False,
847+
mask=mask,
848+
)
849+
else:
850+
return self.replace(
851+
to_replace=to_replace, value=value, inplace=inplace, mask=mask
852+
)
866853

867854
# ---------------------------------------------------------------------
868855

pandas/tests/arrays/categorical/test_replace.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import numpy as np
21
import pytest
32

43
import pandas as pd
@@ -20,18 +19,15 @@
2019
([1, 2], 4, [4, 4, 3], False),
2120
((1, 2, 4), 5, [5, 5, 3], False),
2221
((5, 6), 2, [1, 2, 3], False),
23-
# many-to-many, handled outside of Categorical and results in separate dtype
24-
# except for cases with only 1 unique entry in `value`
25-
([1], [2], [2, 2, 3], True),
26-
([1, 4], [5, 2], [5, 2, 3], True),
22+
([1], [2], [2, 2, 3], False),
23+
([1, 4], [5, 2], [5, 2, 3], False),
2724
# check_categorical sorts categories, which crashes on mixed dtypes
2825
(3, "4", [1, 2, "4"], False),
2926
([1, 2, "3"], "5", ["5", "5", 3], True),
3027
],
3128
)
3229
def test_replace_categorical_series(to_replace, value, expected, flip_categories):
3330
# GH 31720
34-
stays_categorical = not isinstance(value, list) or len(pd.unique(value)) == 1
3531

3632
ser = pd.Series([1, 2, 3], dtype="category")
3733
result = ser.replace(to_replace, value)
@@ -41,10 +37,6 @@ def test_replace_categorical_series(to_replace, value, expected, flip_categories
4137
if flip_categories:
4238
expected = expected.cat.set_categories(expected.cat.categories[::-1])
4339

44-
if not stays_categorical:
45-
# the replace call loses categorical dtype
46-
expected = pd.Series(np.asarray(expected))
47-
4840
tm.assert_series_equal(expected, result, check_category_order=False)
4941
tm.assert_series_equal(expected, ser, check_category_order=False)
5042

pandas/tests/frame/methods/test_replace.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,14 @@ def test_replace_mixed3(self):
624624
expected.iloc[1, 1] = m[1]
625625
tm.assert_frame_equal(result, expected)
626626

627+
def test_replace_nullable_int_with_string_doesnt_cast(self):
628+
# GH#25438 don't cast df['a'] to float64
629+
df = DataFrame({"a": [1, 2, 3, np.nan], "b": ["some", "strings", "here", "he"]})
630+
df["a"] = df["a"].astype("Int64")
631+
632+
res = df.replace("", np.nan)
633+
tm.assert_series_equal(res["a"], df["a"])
634+
627635
@pytest.mark.parametrize("dtype", ["boolean", "Int64", "Float64"])
628636
def test_replace_with_nullable_column(self, dtype):
629637
# GH-44499
@@ -1382,15 +1390,12 @@ def test_replace_value_category_type(self):
13821390

13831391
tm.assert_frame_equal(result, expected)
13841392

1385-
@pytest.mark.xfail(
1386-
reason="category dtype gets changed to object type after replace, see #35268",
1387-
raises=AssertionError,
1388-
)
13891393
def test_replace_dict_category_type(self):
13901394
"""
13911395
Test to ensure category dtypes are maintained
13921396
after replace with dict values
13931397
"""
1398+
# GH#35268, GH#44940
13941399

13951400
# create input dataframe
13961401
input_dict = {"col1": ["a"], "col2": ["obj1"], "col3": ["cat1"]}

pandas/tests/indexing/test_coercion.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1177,6 +1177,7 @@ def test_replace_series_datetime_tz(self, how, to_key, from_key, replacer):
11771177
assert obj.dtype == from_key
11781178

11791179
result = obj.replace(replacer)
1180+
11801181
exp = pd.Series(self.rep[to_key], index=index, name="yyy")
11811182
assert exp.dtype == to_key
11821183

@@ -1197,7 +1198,21 @@ def test_replace_series_datetime_datetime(self, how, to_key, from_key, replacer)
11971198
obj = pd.Series(self.rep[from_key], index=index, name="yyy")
11981199
assert obj.dtype == from_key
11991200

1200-
result = obj.replace(replacer)
1201+
warn = None
1202+
rep_ser = pd.Series(replacer)
1203+
if (
1204+
isinstance(obj.dtype, pd.DatetimeTZDtype)
1205+
and isinstance(rep_ser.dtype, pd.DatetimeTZDtype)
1206+
and obj.dtype != rep_ser.dtype
1207+
):
1208+
# mismatched tz DatetimeArray behavior will change to cast
1209+
# for setitem-like methods with mismatched tzs GH#44940
1210+
warn = FutureWarning
1211+
1212+
msg = "explicitly cast to object"
1213+
with tm.assert_produces_warning(warn, match=msg):
1214+
result = obj.replace(replacer)
1215+
12011216
exp = pd.Series(self.rep[to_key], index=index, name="yyy")
12021217
assert exp.dtype == to_key
12031218

pandas/tests/series/methods/test_replace.py

+109-19
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pandas as pd
77
import pandas._testing as tm
8+
from pandas.core.arrays import IntervalArray
89

910

1011
class TestSeriesReplace:
@@ -148,20 +149,21 @@ def test_replace_with_single_list(self):
148149
tm.assert_series_equal(s, ser)
149150

150151
def test_replace_mixed_types(self):
151-
s = pd.Series(np.arange(5), dtype="int64")
152+
ser = pd.Series(np.arange(5), dtype="int64")
152153

153154
def check_replace(to_rep, val, expected):
154-
sc = s.copy()
155-
r = s.replace(to_rep, val)
155+
sc = ser.copy()
156+
result = ser.replace(to_rep, val)
156157
return_value = sc.replace(to_rep, val, inplace=True)
157158
assert return_value is None
158-
tm.assert_series_equal(expected, r)
159+
tm.assert_series_equal(expected, result)
159160
tm.assert_series_equal(expected, sc)
160161

161-
# MUST upcast to float
162-
e = pd.Series([0.0, 1.0, 2.0, 3.0, 4.0])
162+
# 3.0 can still be held in our int64 series, so we do not upcast GH#44940
163163
tr, v = [3], [3.0]
164-
check_replace(tr, v, e)
164+
check_replace(tr, v, ser)
165+
# Note this matches what we get with the scalars 3 and 3.0
166+
check_replace(tr[0], v[0], ser)
165167

166168
# MUST upcast to float
167169
e = pd.Series([0, 1, 2, 3.5, 4])
@@ -257,10 +259,10 @@ def test_replace2(self):
257259
assert (ser[20:30] == -1).all()
258260

259261
def test_replace_with_dictlike_and_string_dtype(self, nullable_string_dtype):
260-
# GH 32621
261-
s = pd.Series(["one", "two", np.nan], dtype=nullable_string_dtype)
262-
expected = pd.Series(["1", "2", np.nan])
263-
result = s.replace({"one": "1", "two": "2"})
262+
# GH 32621, GH#44940
263+
ser = pd.Series(["one", "two", np.nan], dtype=nullable_string_dtype)
264+
expected = pd.Series(["1", "2", np.nan], dtype=nullable_string_dtype)
265+
result = ser.replace({"one": "1", "two": "2"})
264266
tm.assert_series_equal(expected, result)
265267

266268
def test_replace_with_empty_dictlike(self):
@@ -305,17 +307,18 @@ def test_replace_mixed_types_with_string(self):
305307
"categorical, numeric",
306308
[
307309
(pd.Categorical(["A"], categories=["A", "B"]), [1]),
308-
(pd.Categorical(("A",), categories=["A", "B"]), [1]),
309-
(pd.Categorical(("A", "B"), categories=["A", "B"]), [1, 2]),
310+
(pd.Categorical(["A", "B"], categories=["A", "B"]), [1, 2]),
310311
],
311312
)
312313
def test_replace_categorical(self, categorical, numeric):
313-
# GH 24971
314-
# Do not check if dtypes are equal due to a known issue that
315-
# Categorical.replace sometimes coerces to object (GH 23305)
316-
s = pd.Series(categorical)
317-
result = s.replace({"A": 1, "B": 2})
318-
expected = pd.Series(numeric)
314+
# GH 24971, GH#23305
315+
ser = pd.Series(categorical)
316+
result = ser.replace({"A": 1, "B": 2})
317+
expected = pd.Series(numeric).astype("category")
318+
if 2 not in expected.cat.categories:
319+
# i.e. categories should be [1, 2] even if there are no "B"s present
320+
# GH#44940
321+
expected = expected.cat.add_categories(2)
319322
tm.assert_series_equal(expected, result)
320323

321324
def test_replace_categorical_single(self):
@@ -514,3 +517,90 @@ def test_pandas_replace_na(self):
514517
result = ser.replace(regex_mapping, regex=True)
515518
exp = pd.Series(["CC", "CC", "CC-REPL", "DD", "CC", "", pd.NA], dtype="string")
516519
tm.assert_series_equal(result, exp)
520+
521+
@pytest.mark.parametrize(
522+
"dtype, input_data, to_replace, expected_data",
523+
[
524+
("bool", [True, False], {True: False}, [False, False]),
525+
("int64", [1, 2], {1: 10, 2: 20}, [10, 20]),
526+
("Int64", [1, 2], {1: 10, 2: 20}, [10, 20]),
527+
("float64", [1.1, 2.2], {1.1: 10.1, 2.2: 20.5}, [10.1, 20.5]),
528+
("Float64", [1.1, 2.2], {1.1: 10.1, 2.2: 20.5}, [10.1, 20.5]),
529+
("string", ["one", "two"], {"one": "1", "two": "2"}, ["1", "2"]),
530+
(
531+
pd.IntervalDtype("int64"),
532+
IntervalArray([pd.Interval(1, 2), pd.Interval(2, 3)]),
533+
{pd.Interval(1, 2): pd.Interval(10, 20)},
534+
IntervalArray([pd.Interval(10, 20), pd.Interval(2, 3)]),
535+
),
536+
(
537+
pd.IntervalDtype("float64"),
538+
IntervalArray([pd.Interval(1.0, 2.7), pd.Interval(2.8, 3.1)]),
539+
{pd.Interval(1.0, 2.7): pd.Interval(10.6, 20.8)},
540+
IntervalArray([pd.Interval(10.6, 20.8), pd.Interval(2.8, 3.1)]),
541+
),
542+
(
543+
pd.PeriodDtype("M"),
544+
[pd.Period("2020-05", freq="M")],
545+
{pd.Period("2020-05", freq="M"): pd.Period("2020-06", freq="M")},
546+
[pd.Period("2020-06", freq="M")],
547+
),
548+
],
549+
)
550+
def test_replace_dtype(self, dtype, input_data, to_replace, expected_data):
551+
# GH#33484
552+
ser = pd.Series(input_data, dtype=dtype)
553+
result = ser.replace(to_replace)
554+
expected = pd.Series(expected_data, dtype=dtype)
555+
tm.assert_series_equal(result, expected)
556+
557+
def test_replace_string_dtype(self):
558+
# GH#40732, GH#44940
559+
ser = pd.Series(["one", "two", np.nan], dtype="string")
560+
res = ser.replace({"one": "1", "two": "2"})
561+
expected = pd.Series(["1", "2", np.nan], dtype="string")
562+
tm.assert_series_equal(res, expected)
563+
564+
# GH#31644
565+
ser2 = pd.Series(["A", np.nan], dtype="string")
566+
res2 = ser2.replace("A", "B")
567+
expected2 = pd.Series(["B", np.nan], dtype="string")
568+
tm.assert_series_equal(res2, expected2)
569+
570+
ser3 = pd.Series(["A", "B"], dtype="string")
571+
res3 = ser3.replace("A", pd.NA)
572+
expected3 = pd.Series([pd.NA, "B"], dtype="string")
573+
tm.assert_series_equal(res3, expected3)
574+
575+
def test_replace_string_dtype_list_to_replace(self):
576+
# GH#41215, GH#44940
577+
ser = pd.Series(["abc", "def"], dtype="string")
578+
res = ser.replace(["abc", "any other string"], "xyz")
579+
expected = pd.Series(["xyz", "def"], dtype="string")
580+
tm.assert_series_equal(res, expected)
581+
582+
def test_replace_string_dtype_regex(self):
583+
# GH#31644
584+
ser = pd.Series(["A", "B"], dtype="string")
585+
res = ser.replace(r".", "C", regex=True)
586+
expected = pd.Series(["C", "C"], dtype="string")
587+
tm.assert_series_equal(res, expected)
588+
589+
def test_replace_nullable_numeric(self):
590+
# GH#40732, GH#44940
591+
592+
floats = pd.Series([1.0, 2.0, 3.999, 4.4], dtype=pd.Float64Dtype())
593+
assert floats.replace({1.0: 9}).dtype == floats.dtype
594+
assert floats.replace(1.0, 9).dtype == floats.dtype
595+
assert floats.replace({1.0: 9.0}).dtype == floats.dtype
596+
assert floats.replace(1.0, 9.0).dtype == floats.dtype
597+
598+
res = floats.replace(to_replace=[1.0, 2.0], value=[9.0, 10.0])
599+
assert res.dtype == floats.dtype
600+
601+
ints = pd.Series([1, 2, 3, 4], dtype=pd.Int64Dtype())
602+
assert ints.replace({1: 9}).dtype == ints.dtype
603+
assert ints.replace(1, 9).dtype == ints.dtype
604+
assert ints.replace({1: 9.0}).dtype == ints.dtype
605+
assert ints.replace(1, 9.0).dtype == ints.dtype
606+
# FIXME: ints.replace({1: 9.5}) raises bc of incorrect _can_hold_element

0 commit comments

Comments
 (0)