From 1c28ab41530ca4f9cba29b3ca6eaef98bdb18c63 Mon Sep 17 00:00:00 2001 From: Michael Tiemann <72577720+MichaelTiemannOSC@users.noreply.github.com> Date: Tue, 17 Oct 2023 06:58:22 -0400 Subject: [PATCH 1/3] Fix for issue #55509 Preserve dtype when updating from dataframe whose NA values don't affect original. I am not sure whether the test case belongs in the tests/frame or the tests/dtypes directory. Signed-off-by: Michael Tiemann <72577720+MichaelTiemannOSC@users.noreply.github.com> --- pandas/core/frame.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 09c43822e11e4..cb4ff0839c97e 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -81,6 +81,7 @@ construct_1d_arraylike_from_scalar, construct_2d_arraylike_from_scalar, find_common_type, + find_result_type, infer_dtype_from_scalar, invalidate_string_dtypes, maybe_box_native, @@ -8870,7 +8871,15 @@ def update( if mask.all(): continue - self.loc[:, col] = expressions.where(mask, this, that) + col_dtype = self[col].dtype + update_result = expressions.where(mask, this, that) + # Preserve dtype if udpate_result is all compatible with dtype + if col_dtype != object and update_result.dtype == object: + if all( + col_dtype == find_result_type(col_dtype, x) for x in update_result + ): + update_result = update_result.astype(col_dtype) + self.loc[:, col] = update_result # ---------------------------------------------------------------------- # Data reshaping From 0d3965fd0da2d47fe5130346750774e053a93099 Mon Sep 17 00:00:00 2001 From: Michael Tiemann <72577720+MichaelTiemannOSC@users.noreply.github.com> Date: Tue, 17 Oct 2023 20:08:28 -0400 Subject: [PATCH 2/3] Create test case Created generic test case that probes a number of basic types, not only `bool` (the dtype from the original report).
Signed-off-by: Michael Tiemann <72577720+MichaelTiemannOSC@users.noreply.github.com> --- pandas/core/frame.py | 3 +- pandas/tests/frame/methods/test_update.py | 45 +++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index cb4ff0839c97e..6c3039ed2267d 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -101,6 +101,7 @@ is_integer_dtype, is_iterator, is_list_like, + is_object_dtype, is_scalar, is_sequence, needs_i8_conversion, @@ -8874,7 +8875,7 @@ def update( col_dtype = self[col].dtype update_result = expressions.where(mask, this, that) # Preserve dtype if udpate_result is all compatible with dtype - if col_dtype != object and update_result.dtype == object: + if col_dtype.kind in "?bBiufcSm" and is_object_dtype(update_result.dtype): if all( col_dtype == find_result_type(col_dtype, x) for x in update_result ): diff --git a/pandas/tests/frame/methods/test_update.py b/pandas/tests/frame/methods/test_update.py index 5738a25f26fcb..86abcaf254e7c 100644 --- a/pandas/tests/frame/methods/test_update.py +++ b/pandas/tests/frame/methods/test_update.py @@ -3,6 +3,8 @@ import pandas.util._test_decorators as td +from pandas.core.dtypes.common import pandas_dtype + import pandas as pd from pandas import ( DataFrame, @@ -177,3 +179,46 @@ def test_update_dt_column_with_NaT_create_column(self): {"A": [1.0, 3.0], "B": [pd.NaT, pd.to_datetime("2016-01-01")]} ) tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( + "value ,dtype", + [ + (True, pandas_dtype("bool")), + (1, pandas_dtype("int64")), + (np.uint64(2), pandas_dtype("uint64")), + (3.0, pandas_dtype("float")), + (4.0 + 1j, pandas_dtype("complex")), + ("a", pandas_dtype("string")), + (pd.to_timedelta("1 ms"), pandas_dtype("timedelta64[ns]")), + (np.datetime64("2000-01-01T00:00:00"), pandas_dtype("datetime64[ns]")), + ], + ) + def test_update_preserve_dtype(self, value, dtype): + # GH#55509 + df1 = ( + DataFrame( + { + "idx": 
[1, 2], + "val": [value] * 2, + } + ) + .set_index("idx") + .astype(dtype) + ) + df2 = ( + DataFrame( + { + "idx": [1], + "val": [value], + } + ) + .set_index("idx") + .astype(dtype) + ) + + assert df1.dtypes["val"] == dtype + assert df2.dtypes["val"] == dtype + + df1.update(df2) + + assert df1.dtypes["val"] == dtype From 095d01c6a03ea5c43fe510d3ca18275298c65ec2 Mon Sep 17 00:00:00 2001 From: Michael Tiemann <72577720+MichaelTiemannOSC@users.noreply.github.com> Date: Wed, 18 Oct 2023 08:12:51 -0400 Subject: [PATCH 3/3] Tighten constraints for fixing dtypes It appears that only `bool` and `datetime64` dtypes need special handling to pass `test_update_preserve_dtype`, so no need to see whether int, float, complex, string, or other types need fixing. Signed-off-by: Michael Tiemann <72577720+MichaelTiemannOSC@users.noreply.github.com> --- pandas/core/frame.py | 17 ++++++++++++----- pandas/tests/frame/methods/test_update.py | 1 + 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 2a9185205e328..8028b57e64e79 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -8886,11 +8886,18 @@ def update( col_dtype = self[col].dtype update_result = expressions.where(mask, this, that) # Preserve dtype if udpate_result is all compatible with dtype - if col_dtype.kind in "?bBiufcSm" and is_object_dtype(update_result.dtype): - if all( - col_dtype == find_result_type(col_dtype, x) for x in update_result - ): - update_result = update_result.astype(col_dtype) + # This only happens for `bool` and `datetime64` + if col_dtype.kind in "bM" and is_object_dtype(update_result.dtype): + try: + if all( + col_dtype == find_result_type(col_dtype, x) + for x in update_result + ): + update_result = update_result.astype(col_dtype) + except TypeError: + # Do nothing if we cannot interpret `col_dtype` as a data type + # e.g. 
`datetime64[ns, UTC]` + pass self.loc[:, col] = update_result # ---------------------------------------------------------------------- diff --git a/pandas/tests/frame/methods/test_update.py b/pandas/tests/frame/methods/test_update.py index 86abcaf254e7c..234961e4fd5fb 100644 --- a/pandas/tests/frame/methods/test_update.py +++ b/pandas/tests/frame/methods/test_update.py @@ -191,6 +191,7 @@ def test_update_dt_column_with_NaT_create_column(self): ("a", pandas_dtype("string")), (pd.to_timedelta("1 ms"), pandas_dtype("timedelta64[ns]")), (np.datetime64("2000-01-01T00:00:00"), pandas_dtype("datetime64[ns]")), + (pd.Timestamp("2000-01-01T00:00:00Z"), pandas_dtype("datetime64[ns, UTC]")), ], ) def test_update_preserve_dtype(self, value, dtype):