diff --git a/doc/source/whatsnew/v2.1.2.rst b/doc/source/whatsnew/v2.1.2.rst index a13a273b01257..6b9b233228cdb 100644 --- a/doc/source/whatsnew/v2.1.2.rst +++ b/doc/source/whatsnew/v2.1.2.rst @@ -43,6 +43,7 @@ Bug fixes - Fixed bug in :meth:`DataFrame.__setitem__` not inferring string dtype for zero-dimensional array with ``infer_string=True`` (:issue:`55366`) - Fixed bug in :meth:`DataFrame.idxmin` and :meth:`DataFrame.idxmax` raising for arrow dtypes (:issue:`55368`) - Fixed bug in :meth:`DataFrame.interpolate` raising incorrect error message (:issue:`55347`) +- Fixed bug in :meth:`DataFrame.update` bool dtype being converted to object (:issue:`55509`) - Fixed bug in :meth:`Index.insert` raising when inserting ``None`` into :class:`Index` with ``dtype="string[pyarrow_numpy]"`` (:issue:`55365`) - Fixed bug in :meth:`Series.all` and :meth:`Series.any` not treating missing values correctly for ``dtype="string[pyarrow_numpy]"`` (:issue:`55367`) - Fixed bug in :meth:`Series.floordiv` for :class:`ArrowDtype` (:issue:`55561`) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index df413fda0255a..e5e1a027e4730 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -8,6 +8,7 @@ alignment and a host of useful data manipulation methods having to do with the labeling information """ + from __future__ import annotations import collections @@ -8764,11 +8765,30 @@ def update( if not isinstance(other, DataFrame): other = DataFrame(other) - other = other.reindex(self.index) + indexes_intersection = other.index.intersection( + self.index + ) # order is important + if not len(indexes_intersection): + raise ValueError( + "Can't update dataframe when other has no index in common with " + "this dataframe." + ) + + if other.index.is_unique: + indexes_this = indexes_intersection + if self.index.is_unique: + indexes_that = indexes_intersection + else: + full_indexes_this = self.index.take( + self.index.get_indexer_for(indexes_intersection) + ) + indexes_that = indexes_intersection.reindex(full_indexes_this)[0] + else: + raise ValueError("Update not allowed with duplicate indexes on other.") for col in self.columns.intersection(other.columns): - this = self[col]._values - that = other[col]._values + this = self.loc[indexes_this, col]._values + that = other.loc[indexes_that, col]._values if filter_func is not None: mask = ~filter_func(this) | isna(that) @@ -8788,7 +8808,7 @@ def update( if mask.all(): continue - self.loc[:, col] = self[col].where(mask, that) + self.loc[indexes_this, col] = self.loc[indexes_this, col].where(mask, that) # ---------------------------------------------------------------------- # Data reshaping @@ -10218,9 +10238,11 @@ def _append( index = Index( [other.name], - name=self.index.names - if isinstance(self.index, MultiIndex) - else self.index.name, + name=( + self.index.names + if isinstance(self.index, MultiIndex) + else self.index.name + ), ) row_df = other.to_frame().T # infer_objects is needed for diff --git a/pandas/tests/frame/methods/test_update.py b/pandas/tests/frame/methods/test_update.py index 788c6220b2477..107642fc88eab 100644 --- a/pandas/tests/frame/methods/test_update.py +++ b/pandas/tests/frame/methods/test_update.py @@ -184,3 +184,47 @@ def test_update_dt_column_with_NaT_create_column(self): {"A": [1.0, 3.0], "B": [pd.NaT, pd.to_datetime("2016-01-01")]} ) tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( + "value_df, value_other, dtype", + [ + (True, False, bool), + (1, 2, int), + (np.uint64(1), np.uint(2), np.dtype("uint64")), + (1.0, 2.0, float), + (1.0 + 1j, 2.0 + 2j, complex), + ("a", "b", pd.StringDtype()), + ( + pd.to_timedelta("1 ms"), + pd.to_timedelta("2 ms"), + np.dtype("timedelta64[ns]"), + ), + ( + np.datetime64("2000-01-01T00:00:00"), + np.datetime64("2000-01-02T00:00:00"), + np.dtype("datetime64[ns]"), + ), + ], + ) + def test_update_preserve_dtype(self, value_df, value_other, dtype): + # GH#55509 + df = DataFrame({"a": [value_df] * 2}, index=[1, 2]) + other = DataFrame({"a": [value_other]}, index=[1]) + expected = DataFrame({"a": [value_other, value_df]}, index=[1, 2]) + df.update(other) + tm.assert_frame_equal(df, expected) + + def test_update_raises_on_duplicate_argument_index(self): + # GH#55509 + df = DataFrame({"a": [1, 1]}, index=[1, 2]) + other = DataFrame({"a": [2, 3]}, index=[1, 1]) + with pytest.raises(ValueError, match="duplicate index"): + df.update(other) + + def test_update_on_duplicate_frame_unique_argument_index(self): + # GH#55509 + df = DataFrame({"a": [1, 1, 1]}, index=[1, 1, 2]) + other = DataFrame({"a": [2, 3]}, index=[1, 2]) + expected = DataFrame({"a": [2, 2, 3]}, index=[1, 1, 2]) + df.update(other) + tm.assert_frame_equal(df, expected)