diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 70a5ac69011d1..8028b57e64e79 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -81,6 +81,7 @@ construct_1d_arraylike_from_scalar, construct_2d_arraylike_from_scalar, find_common_type, + find_result_type, infer_dtype_from_scalar, invalidate_string_dtypes, maybe_box_native, @@ -100,6 +101,7 @@ is_integer_dtype, is_iterator, is_list_like, + is_object_dtype, is_scalar, is_sequence, needs_i8_conversion, @@ -8881,7 +8883,22 @@ def update( if mask.all(): continue - self.loc[:, col] = expressions.where(mask, this, that) + col_dtype = self[col].dtype + update_result = expressions.where(mask, this, that) + # Preserve dtype if udpate_result is all compatible with dtype + # This only happens for `bool` and `datetime64` + if col_dtype.kind in "bM" and is_object_dtype(update_result.dtype): + try: + if all( + col_dtype == find_result_type(col_dtype, x) + for x in update_result + ): + update_result = update_result.astype(col_dtype) + except TypeError: + # Do nothing if we cannot interpret `col_dtype` as a data type + # e.g. `datetime64[ns, UTC]` + pass + self.loc[:, col] = update_result # ---------------------------------------------------------------------- # Data reshaping diff --git a/pandas/tests/frame/methods/test_update.py b/pandas/tests/frame/methods/test_update.py index 5738a25f26fcb..234961e4fd5fb 100644 --- a/pandas/tests/frame/methods/test_update.py +++ b/pandas/tests/frame/methods/test_update.py @@ -3,6 +3,8 @@ import pandas.util._test_decorators as td +from pandas.core.dtypes.common import pandas_dtype + import pandas as pd from pandas import ( DataFrame, @@ -177,3 +179,47 @@ def test_update_dt_column_with_NaT_create_column(self): {"A": [1.0, 3.0], "B": [pd.NaT, pd.to_datetime("2016-01-01")]} ) tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( + "value ,dtype", + [ + (True, pandas_dtype("bool")), + (1, pandas_dtype("int64")), + (np.uint64(2), pandas_dtype("uint64")), + (3.0, pandas_dtype("float")), + (4.0 + 1j, pandas_dtype("complex")), + ("a", pandas_dtype("string")), + (pd.to_timedelta("1 ms"), pandas_dtype("timedelta64[ns]")), + (np.datetime64("2000-01-01T00:00:00"), pandas_dtype("datetime64[ns]")), + (pd.Timestamp("2000-01-01T00:00:00Z"), pandas_dtype("datetime64[ns, UTC]")), + ], + ) + def test_update_preserve_dtype(self, value, dtype): + # GH#55509 + df1 = ( + DataFrame( + { + "idx": [1, 2], + "val": [value] * 2, + } + ) + .set_index("idx") + .astype(dtype) + ) + df2 = ( + DataFrame( + { + "idx": [1], + "val": [value], + } + ) + .set_index("idx") + .astype(dtype) + ) + + assert df1.dtypes["val"] == dtype + assert df2.dtypes["val"] == dtype + + df1.update(df2) + + assert df1.dtypes["val"] == dtype