Skip to content

Commit b1aba2f

Browse files
aureliobarbosapmhatre1
authored andcommitted
BUG: dataframe.update coercing dtype (pandas-dev#57637)
1 parent e95c2dc commit b1aba2f

File tree

2 files changed

+71
-4
lines changed

2 files changed

+71
-4
lines changed

pandas/core/frame.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -8706,6 +8706,10 @@ def update(
87068706
dict.update : Similar method for dictionaries.
87078707
DataFrame.merge : For column(s)-on-column(s) operations.
87088708
8709+
Notes
8710+
-----
8711+
1. Duplicate indices on `other` are not supported and raises `ValueError`.
8712+
87098713
Examples
87108714
--------
87118715
>>> df = pd.DataFrame({"A": [1, 2, 3], "B": [400, 500, 600]})
@@ -8778,11 +8782,22 @@ def update(
87788782
if not isinstance(other, DataFrame):
87798783
other = DataFrame(other)
87808784

8781-
other = other.reindex(self.index)
8785+
if other.index.has_duplicates:
8786+
raise ValueError("Update not allowed with duplicate indexes on other.")
8787+
8788+
index_intersection = other.index.intersection(self.index)
8789+
if index_intersection.empty:
8790+
raise ValueError(
8791+
"Update not allowed when the index on `other` has no intersection "
8792+
"with this dataframe."
8793+
)
8794+
8795+
other = other.reindex(index_intersection)
8796+
this_data = self.loc[index_intersection]
87828797

87838798
for col in self.columns.intersection(other.columns):
8784-
this = self[col]._values
8785-
that = other[col]._values
8799+
this = this_data[col]
8800+
that = other[col]
87868801

87878802
if filter_func is not None:
87888803
mask = ~filter_func(this) | isna(that)
@@ -8802,7 +8817,7 @@ def update(
88028817
if mask.all():
88038818
continue
88048819

8805-
self.loc[:, col] = self[col].where(mask, that)
8820+
self.loc[index_intersection, col] = this.where(mask, that)
88068821

88078822
# ----------------------------------------------------------------------
88088823
# Data reshaping

pandas/tests/frame/methods/test_update.py

+52
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,55 @@ def test_update_dt_column_with_NaT_create_column(self):
184184
{"A": [1.0, 3.0], "B": [pd.NaT, pd.to_datetime("2016-01-01")]}
185185
)
186186
tm.assert_frame_equal(df, expected)
187+
188+
@pytest.mark.parametrize(
189+
"value_df, value_other, dtype",
190+
[
191+
(True, False, bool),
192+
(1, 2, int),
193+
(1.0, 2.0, float),
194+
(1.0 + 1j, 2.0 + 2j, complex),
195+
(np.uint64(1), np.uint(2), np.dtype("ubyte")),
196+
(np.uint64(1), np.uint(2), np.dtype("intc")),
197+
("a", "b", pd.StringDtype()),
198+
(
199+
pd.to_timedelta("1 ms"),
200+
pd.to_timedelta("2 ms"),
201+
np.dtype("timedelta64[ns]"),
202+
),
203+
(
204+
np.datetime64("2000-01-01T00:00:00"),
205+
np.datetime64("2000-01-02T00:00:00"),
206+
np.dtype("datetime64[ns]"),
207+
),
208+
],
209+
)
210+
def test_update_preserve_dtype(self, value_df, value_other, dtype):
211+
# GH#55509
212+
df = DataFrame({"a": [value_df] * 2}, index=[1, 2], dtype=dtype)
213+
other = DataFrame({"a": [value_other]}, index=[1], dtype=dtype)
214+
expected = DataFrame({"a": [value_other, value_df]}, index=[1, 2], dtype=dtype)
215+
df.update(other)
216+
tm.assert_frame_equal(df, expected)
217+
218+
def test_update_raises_on_duplicate_argument_index(self):
219+
# GH#55509
220+
df = DataFrame({"a": [1, 1]}, index=[1, 2])
221+
other = DataFrame({"a": [2, 3]}, index=[1, 1])
222+
with pytest.raises(ValueError, match="duplicate index"):
223+
df.update(other)
224+
225+
def test_update_raises_without_intersection(self):
226+
# GH#55509
227+
df = DataFrame({"a": [1]}, index=[1])
228+
other = DataFrame({"a": [2]}, index=[2])
229+
with pytest.raises(ValueError, match="no intersection"):
230+
df.update(other)
231+
232+
def test_update_on_duplicate_frame_unique_argument_index(self):
233+
# GH#55509
234+
df = DataFrame({"a": [1, 1, 1]}, index=[1, 1, 2], dtype=np.dtype("intc"))
235+
other = DataFrame({"a": [2, 3]}, index=[1, 2], dtype=np.dtype("intc"))
236+
expected = DataFrame({"a": [2, 2, 3]}, index=[1, 1, 2], dtype=np.dtype("intc"))
237+
df.update(other)
238+
tm.assert_frame_equal(df, expected)

0 commit comments

Comments
 (0)