Skip to content

Commit 0c1d126

Browse files
committed
ENH: add argument to preserve dtypes of common columns in combine_first
1 parent 3bd3d1e commit 0c1d126

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

pandas/core/frame.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -6425,7 +6425,7 @@ def combine(
64256425
# convert_objects just in case
64266426
return self._constructor(result, index=new_index, columns=new_columns)
64276427

6428-
def combine_first(self, other: DataFrame) -> DataFrame:
6428+
def combine_first(self, other: DataFrame, preserve_dtypes: bool = False) -> DataFrame:
64296429
"""
64306430
Update null elements with value in the same location in `other`.
64316431
@@ -6438,6 +6438,11 @@ def combine_first(self, other: DataFrame) -> DataFrame:
64386438
other : DataFrame
64396439
Provided DataFrame to use to fill null values.
64406440
6441+
preserve_dtypes : bool, default False
6442+
Try to preserve the column dtypes after combining.
6443+
6444+
.. versionadded:: 1.2.1
6445+
64416446
Returns
64426447
-------
64436448
DataFrame
@@ -6482,7 +6487,18 @@ def combiner(x, y):
64826487

64836488
return expressions.where(mask, y_values, x_values)
64846489

6485-
return self.combine(other, combiner, overwrite=False)
6490+
combined = self.combine(other, combiner, overwrite=False)
6491+
6492+
if preserve_dtypes:
6493+
dtypes = {
6494+
col: find_common_type([self.dtypes[col], other.dtypes[col]])
6495+
for col in self.columns.intersection(other.columns)
6496+
}
6497+
6498+
if dtypes:
6499+
combined = combined.astype(dtypes)
6500+
6501+
return combined
64866502

64876503
def update(
64886504
self,

pandas/tests/frame/methods/test_combine_first.py

+44
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ def test_combine_first_mixed(self):
2424
combined = f.combine_first(g)
2525
tm.assert_frame_equal(combined, exp)
2626

27+
exp = DataFrame(
28+
{"A": list("abab"), "B": [0, 1, 0, 1]}, index=[0, 1, 5, 6]
29+
)
30+
combined = f.combine_first(g, preserve_dtypes=True)
31+
tm.assert_frame_equal(combined, exp)
32+
2733
def test_combine_first(self, float_frame):
2834
# disjoint
2935
head, tail = float_frame[:5], float_frame[5:]
@@ -363,9 +369,16 @@ def test_combine_first_int(self):
363369
expected_12 = DataFrame({"a": [0, 1, 3, 5]}, dtype="float64")
364370
tm.assert_frame_equal(result_12, expected_12)
365371

372+
result_12 = df1.combine_first(df2, preserve_dtypes=True)
373+
expected_12 = DataFrame({"a": [0, 1, 3, 5]})
374+
tm.assert_frame_equal(result_12, expected_12)
375+
366376
result_21 = df2.combine_first(df1)
367377
expected_21 = DataFrame({"a": [1, 4, 3, 5]}, dtype="float64")
378+
tm.assert_frame_equal(result_21, expected_21)
368379

380+
result_21 = df2.combine_first(df1, preserve_dtypes=True)
381+
expected_21 = DataFrame({"a": [1, 4, 3, 5]})
369382
tm.assert_frame_equal(result_21, expected_21)
370383

371384
@pytest.mark.parametrize("val", [1, 1.0])
@@ -439,3 +452,34 @@ def test_combine_first_with_nan_multiindex():
439452
index=mi_expected,
440453
)
441454
tm.assert_frame_equal(res, expected)
455+
456+
def test_combine_preserve_dtypes():
457+
a = Series(["a", "b"], index=range(2))
458+
b = Series(range(2), index=range(2))
459+
f = DataFrame({"A": a, "B": b})
460+
461+
c = Series(["a", "b"], index=range(5, 7))
462+
b = Series(range(-1, 1), index=range(5, 7))
463+
g = DataFrame({"B": b, "C": c})
464+
465+
exp = DataFrame(
466+
{
467+
"A": ["a", "b", np.nan, np.nan],
468+
"B": [0.0, 1.0, -1.0, 0.0],
469+
"C": [np.nan, np.nan, "a", "b"]
470+
},
471+
index=[0, 1, 5, 6]
472+
)
473+
combined = f.combine_first(g)
474+
tm.assert_frame_equal(combined, exp)
475+
476+
exp = DataFrame(
477+
{
478+
"A": ["a", "b", np.nan, np.nan],
479+
"B": [0, 1, -1, 0],
480+
"C": [np.nan, np.nan, "a", "b"]
481+
},
482+
index=[0, 1, 5, 6]
483+
)
484+
combined = f.combine_first(g, preserve_dtypes=True)
485+
tm.assert_frame_equal(combined, exp)

0 commit comments

Comments
 (0)