Skip to content

ENH: try to preserve the dtype on combine_first for the case where the two DataFrame objects have the same columns #39051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 15, 2021
Merged
30 changes: 30 additions & 0 deletions doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,36 @@ Notable bug fixes
These are bug fixes that might have notable behavior changes.


Preserve dtypes in :meth:`~pandas.DataFrame.combine_first`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:meth:`~pandas.DataFrame.combine_first` will now preserve dtypes (:issue:`7509`)

.. ipython:: python

   df1 = pd.DataFrame({"A": [1, 2, 3], "B": [1, 2, 3]}, index=[0, 1, 2])
   df1
   df2 = pd.DataFrame({"B": [4, 5, 6], "C": [1, 2, 3]}, index=[2, 3, 4])
   df2
   combined = df1.combine_first(df2)

*pandas 1.2.x*

.. code-block:: ipython

   In [1]: combined.dtypes
   Out[1]:
   A    float64
   B    float64
   C    float64
   dtype: object

*pandas 1.3.0*

.. ipython:: python

   combined.dtypes


.. _whatsnew_130.api_breaking.deps:

Expand Down
13 changes: 12 additions & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6482,7 +6482,18 @@ def combiner(x, y):

return expressions.where(mask, y_values, x_values)

return self.combine(other, combiner, overwrite=False)
combined = self.combine(other, combiner, overwrite=False)

dtypes = {
col: find_common_type([self.dtypes[col], other.dtypes[col]])
for col in self.columns.intersection(other.columns)
if not is_dtype_equal(combined.dtypes[col], self.dtypes[col])
}

if dtypes:
combined = combined.astype(dtypes)

return combined

def update(
self,
Expand Down
83 changes: 63 additions & 20 deletions pandas/tests/frame/methods/test_combine_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import pytest

from pandas.core.dtypes.cast import find_common_type, is_dtype_equal

import pandas as pd
from pandas import DataFrame, Index, MultiIndex, Series
import pandas._testing as tm
Expand All @@ -18,9 +20,7 @@ def test_combine_first_mixed(self):
b = Series(range(2), index=range(5, 7))
g = DataFrame({"A": a, "B": b})

exp = DataFrame(
{"A": list("abab"), "B": [0.0, 1.0, 0.0, 1.0]}, index=[0, 1, 5, 6]
)
exp = DataFrame({"A": list("abab"), "B": [0, 1, 0, 1]}, index=[0, 1, 5, 6])
combined = f.combine_first(g)
tm.assert_frame_equal(combined, exp)

Expand Down Expand Up @@ -144,7 +144,7 @@ def test_combine_first_return_obj_type_with_bools(self):
)
df2 = DataFrame([[-42.6, np.nan, True], [-5.0, 1.6, False]], index=[1, 2])

expected = Series([True, True, False], name=2, dtype=object)
expected = Series([True, True, False], name=2, dtype=bool)

result_12 = df1.combine_first(df2)[2]
tm.assert_series_equal(result_12, expected)
Expand All @@ -157,22 +157,22 @@ def test_combine_first_return_obj_type_with_bools(self):
(
(
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
[None, None, None],
[pd.NaT, pd.NaT, pd.NaT],
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
),
(
[None, None, None],
[pd.NaT, pd.NaT, pd.NaT],
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
),
(
[datetime(2000, 1, 2), None, None],
[datetime(2000, 1, 2), pd.NaT, pd.NaT],
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
[datetime(2000, 1, 2), datetime(2000, 1, 2), datetime(2000, 1, 3)],
),
(
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
[datetime(2000, 1, 2), None, None],
[datetime(2000, 1, 2), pd.NaT, pd.NaT],
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
),
),
Expand All @@ -196,13 +196,13 @@ def test_combine_first_align_nan(self):

res = dfa.combine_first(dfb)
exp = DataFrame(
{"a": [pd.Timestamp("2011-01-01"), pd.NaT], "b": [2.0, 5.0]},
{"a": [pd.Timestamp("2011-01-01"), pd.NaT], "b": [2, 5]},
columns=["a", "b"],
)
tm.assert_frame_equal(res, exp)
assert res["a"].dtype == "datetime64[ns]"
# ToDo: this must be int64
assert res["b"].dtype == "float64"
assert res["b"].dtype == "int64"

res = dfa.iloc[:0].combine_first(dfb)
exp = DataFrame({"a": [np.nan, np.nan], "b": [4, 5]}, columns=["a", "b"])
Expand All @@ -219,14 +219,12 @@ def test_combine_first_timezone(self):
columns=["UTCdatetime", "abc"],
data=data1,
index=pd.date_range("20140627", periods=1),
dtype="object",
)
data2 = pd.to_datetime("20121212 12:12").tz_localize("UTC")
df2 = DataFrame(
columns=["UTCdatetime", "xyz"],
data=data2,
index=pd.date_range("20140628", periods=1),
dtype="object",
)
res = df2[["UTCdatetime"]].combine_first(df1)
exp = DataFrame(
Expand All @@ -239,13 +237,10 @@ def test_combine_first_timezone(self):
},
columns=["UTCdatetime", "abc"],
index=pd.date_range("20140627", periods=2, freq="D"),
dtype="object",
)
assert res["UTCdatetime"].dtype == "datetime64[ns, UTC]"
assert res["abc"].dtype == "datetime64[ns, UTC]"
# Need to cast all to "object" because combine_first does not retain dtypes:
# GH Issue 7509
res = res.astype("object")

tm.assert_frame_equal(res, exp)

# see gh-10567
Expand Down Expand Up @@ -360,12 +355,11 @@ def test_combine_first_int(self):
df2 = DataFrame({"a": [1, 4]}, dtype="int64")

result_12 = df1.combine_first(df2)
expected_12 = DataFrame({"a": [0, 1, 3, 5]}, dtype="float64")
expected_12 = DataFrame({"a": [0, 1, 3, 5]})
tm.assert_frame_equal(result_12, expected_12)

result_21 = df2.combine_first(df1)
expected_21 = DataFrame({"a": [1, 4, 3, 5]}, dtype="float64")

expected_21 = DataFrame({"a": [1, 4, 3, 5]})
tm.assert_frame_equal(result_21, expected_21)

@pytest.mark.parametrize("val", [1, 1.0])
Expand Down Expand Up @@ -404,11 +398,38 @@ def test_combine_first_string_dtype_only_na(self):
def test_combine_first_timestamp_bug(scalar1, scalar2, nulls_fixture):
# GH28481
na_value = nulls_fixture

frame = DataFrame([[na_value, na_value]], columns=["a", "b"])
other = DataFrame([[scalar1, scalar2]], columns=["b", "c"])

common_dtype = find_common_type([frame.dtypes["b"], other.dtypes["b"]])

if is_dtype_equal(common_dtype, "object") or frame.dtypes["b"] == other.dtypes["b"]:
val = scalar1
else:
val = na_value

result = frame.combine_first(other)

expected = DataFrame([[na_value, val, scalar2]], columns=["a", "b", "c"])

expected["b"] = expected["b"].astype(common_dtype)

tm.assert_frame_equal(result, expected)


def test_combine_first_timestamp_bug_NaT():
    """Check that an all-``NaT`` frame is filled from a datetime frame.

    The calling frame has columns ``a`` and ``b`` holding only ``pd.NaT``;
    the other frame has columns ``b`` and ``c`` holding datetimes.
    """
    # GH28481
    frame = DataFrame([[pd.NaT, pd.NaT]], columns=["a", "b"])
    other = DataFrame(
        [[datetime(2020, 1, 1), datetime(2020, 1, 2)]], columns=["b", "c"]
    )

    result = frame.combine_first(other)
    # "a" is only present in ``frame`` and stays NaT; "b" and "c" carry the
    # datetimes supplied by ``other``.
    expected = DataFrame(
        [[pd.NaT, datetime(2020, 1, 1), datetime(2020, 1, 2)]],
        columns=["a", "b", "c"],
    )

    tm.assert_frame_equal(result, expected)


Expand Down Expand Up @@ -439,3 +460,25 @@ def test_combine_first_with_nan_multiindex():
index=mi_expected,
)
tm.assert_frame_equal(res, expected)


def test_combine_preserve_dtypes():
    """Check ``combine_first`` on frames with disjoint indexes and one shared column.

    The first frame has columns ``A`` (strings) and ``B`` (integers) on index
    0-1; the second has ``B`` (integers) and ``C`` (strings) on index 5-6.
    The expected frame lists ``B`` as plain integers.
    """
    # GH7509
    first_index = range(2)
    second_index = range(5, 7)

    df1 = DataFrame(
        {
            "A": Series(["a", "b"], index=first_index),
            "B": Series(range(2), index=first_index),
        }
    )
    df2 = DataFrame(
        {
            "B": Series(range(-1, 1), index=second_index),
            "C": Series(["a", "b"], index=second_index),
        }
    )

    expected_columns = {
        "A": ["a", "b", np.nan, np.nan],
        "B": [0, 1, -1, 0],
        "C": [np.nan, np.nan, "a", "b"],
    }
    expected = DataFrame(expected_columns, index=[0, 1, 5, 6])

    tm.assert_frame_equal(df1.combine_first(df2), expected)