Skip to content

ENH: try to preserve the dtype on combine_first for the case where the two DataFrame objects have the same columns #39051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 15, 2021
Merged
27 changes: 26 additions & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6482,7 +6482,32 @@ def combiner(x, y):

return expressions.where(mask, y_values, x_values)

return self.combine(other, combiner, overwrite=False)
combined = self.combine(other, combiner, overwrite=False)

dtypes = {}

for col in self.columns.intersection(other.columns):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be a simple list-comprehension

try:
# if the column has different dtype in the
# DataFrame objects then add the common dtype
# to the columns dtype conversion dict
if combined.dtypes[col] != self.dtypes[col]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use is_dtype_equal here

dtypes[col] = find_common_type(
[self.dtypes[col], other.dtypes[col]]
)
except TypeError:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do not want to do multiple try/excepts ever as these tend to hide errors.
in fact you should not need here at all. find_common_type will always succeed (it could of course be object).

# numpy dtype was compared with pandas dtype
try:
# just try to apply the initial column dtype
combined[col] = combined[col].astype(self.dtypes[col])
except ValueError:
# could not apply the initial dtype, so skip
pass

if dtypes:
combined = combined.astype(dtypes)

return combined

def update(
self,
Expand Down
85 changes: 65 additions & 20 deletions pandas/tests/frame/methods/test_combine_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import pytest

from pandas.core.dtypes.cast import find_common_type

import pandas as pd
from pandas import DataFrame, Index, MultiIndex, Series
import pandas._testing as tm
Expand All @@ -18,9 +20,7 @@ def test_combine_first_mixed(self):
b = Series(range(2), index=range(5, 7))
g = DataFrame({"A": a, "B": b})

exp = DataFrame(
{"A": list("abab"), "B": [0.0, 1.0, 0.0, 1.0]}, index=[0, 1, 5, 6]
)
exp = DataFrame({"A": list("abab"), "B": [0, 1, 0, 1]}, index=[0, 1, 5, 6])
combined = f.combine_first(g)
tm.assert_frame_equal(combined, exp)

Expand Down Expand Up @@ -144,7 +144,7 @@ def test_combine_first_return_obj_type_with_bools(self):
)
df2 = DataFrame([[-42.6, np.nan, True], [-5.0, 1.6, False]], index=[1, 2])

expected = Series([True, True, False], name=2, dtype=object)
expected = Series([True, True, False], name=2, dtype=bool)

result_12 = df1.combine_first(df2)[2]
tm.assert_series_equal(result_12, expected)
Expand All @@ -157,22 +157,22 @@ def test_combine_first_return_obj_type_with_bools(self):
(
(
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
[None, None, None],
[pd.NaT, pd.NaT, pd.NaT],
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
),
(
[None, None, None],
[pd.NaT, pd.NaT, pd.NaT],
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
),
(
[datetime(2000, 1, 2), None, None],
[datetime(2000, 1, 2), pd.NaT, pd.NaT],
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
[datetime(2000, 1, 2), datetime(2000, 1, 2), datetime(2000, 1, 3)],
),
(
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
[datetime(2000, 1, 2), None, None],
[datetime(2000, 1, 2), pd.NaT, pd.NaT],
[datetime(2000, 1, 1), datetime(2000, 1, 2), datetime(2000, 1, 3)],
),
),
Expand All @@ -196,13 +196,13 @@ def test_combine_first_align_nan(self):

res = dfa.combine_first(dfb)
exp = DataFrame(
{"a": [pd.Timestamp("2011-01-01"), pd.NaT], "b": [2.0, 5.0]},
{"a": [pd.Timestamp("2011-01-01"), pd.NaT], "b": [2, 5]},
columns=["a", "b"],
)
tm.assert_frame_equal(res, exp)
assert res["a"].dtype == "datetime64[ns]"
# ToDo: this must be int64
assert res["b"].dtype == "float64"
assert res["b"].dtype == "int64"

res = dfa.iloc[:0].combine_first(dfb)
exp = DataFrame({"a": [np.nan, np.nan], "b": [4, 5]}, columns=["a", "b"])
Expand All @@ -219,14 +219,12 @@ def test_combine_first_timezone(self):
columns=["UTCdatetime", "abc"],
data=data1,
index=pd.date_range("20140627", periods=1),
dtype="object",
)
data2 = pd.to_datetime("20121212 12:12").tz_localize("UTC")
df2 = DataFrame(
columns=["UTCdatetime", "xyz"],
data=data2,
index=pd.date_range("20140628", periods=1),
dtype="object",
)
res = df2[["UTCdatetime"]].combine_first(df1)
exp = DataFrame(
Expand All @@ -239,13 +237,10 @@ def test_combine_first_timezone(self):
},
columns=["UTCdatetime", "abc"],
index=pd.date_range("20140627", periods=2, freq="D"),
dtype="object",
)
assert res["UTCdatetime"].dtype == "datetime64[ns, UTC]"
assert res["abc"].dtype == "datetime64[ns, UTC]"
# Need to cast all to "obejct" because combine_first does not retain dtypes:
# GH Issue 7509
res = res.astype("object")

tm.assert_frame_equal(res, exp)

# see gh-10567
Expand Down Expand Up @@ -360,12 +355,11 @@ def test_combine_first_int(self):
df2 = DataFrame({"a": [1, 4]}, dtype="int64")

result_12 = df1.combine_first(df2)
expected_12 = DataFrame({"a": [0, 1, 3, 5]}, dtype="float64")
expected_12 = DataFrame({"a": [0, 1, 3, 5]})
tm.assert_frame_equal(result_12, expected_12)

result_21 = df2.combine_first(df1)
expected_21 = DataFrame({"a": [1, 4, 3, 5]}, dtype="float64")

expected_21 = DataFrame({"a": [1, 4, 3, 5]})
tm.assert_frame_equal(result_21, expected_21)

@pytest.mark.parametrize("val", [1, 1.0])
Expand Down Expand Up @@ -404,11 +398,41 @@ def test_combine_first_string_dtype_only_na(self):
def test_combine_first_timestamp_bug(scalar1, scalar2, nulls_fixture):
# GH28481
na_value = nulls_fixture

frame = DataFrame([[na_value, na_value]], columns=["a", "b"])
other = DataFrame([[scalar1, scalar2]], columns=["b", "c"])

try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't use try/except in tests. explicity specify the expected

common_dtype = find_common_type([frame.dtypes["b"], other.dtypes["b"]])
except TypeError:
common_dtype = "object"

if common_dtype == "object" or frame.dtypes["b"] == other.dtypes["b"]:
val = scalar1
else:
val = na_value

result = frame.combine_first(other)
expected = DataFrame([[na_value, scalar1, scalar2]], columns=["a", "b", "c"])

expected = DataFrame([[na_value, val, scalar2]], columns=["a", "b", "c"])

expected["b"] = expected["b"].astype(common_dtype)

tm.assert_frame_equal(result, expected)


def test_combine_first_timestamp_bug_NaT():
# GH28481
frame = DataFrame([[pd.NaT, pd.NaT]], columns=["a", "b"])
other = DataFrame(
[[datetime(2020, 1, 1), datetime(2020, 1, 2)]], columns=["b", "c"]
)

result = frame.combine_first(other)
expected = DataFrame(
[[pd.NaT, datetime(2020, 1, 1), datetime(2020, 1, 2)]], columns=["a", "b", "c"]
)

tm.assert_frame_equal(result, expected)


Expand Down Expand Up @@ -439,3 +463,24 @@ def test_combine_first_with_nan_multiindex():
index=mi_expected,
)
tm.assert_frame_equal(res, expected)


def test_combine_preserve_dtypes():
a = Series(["a", "b"], index=range(2))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add the issue number as a comment

b = Series(range(2), index=range(2))
f = DataFrame({"A": a, "B": b})

c = Series(["a", "b"], index=range(5, 7))
b = Series(range(-1, 1), index=range(5, 7))
g = DataFrame({"B": b, "C": c})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: can you avoid 1-letter variable names? makes it harder to grep for things


exp = DataFrame(
{
"A": ["a", "b", np.nan, np.nan],
"B": [0, 1, -1, 0],
"C": [np.nan, np.nan, "a", "b"],
},
index=[0, 1, 5, 6],
)
combined = f.combine_first(g)
tm.assert_frame_equal(combined, exp)