diff --git a/doc/source/whatsnew/v1.3.3.rst b/doc/source/whatsnew/v1.3.3.rst index b4265c1bc5ddd..00409cf963ab3 100644 --- a/doc/source/whatsnew/v1.3.3.rst +++ b/doc/source/whatsnew/v1.3.3.rst @@ -17,6 +17,7 @@ Fixed regressions - Fixed regression in :class:`DataFrame` constructor failing to broadcast for defined :class:`Index` and len one list of :class:`Timestamp` (:issue:`42810`) - Performance regression in :meth:`core.window.ewm.ExponentialMovingWindow.mean` (:issue:`42333`) - Fixed regression in :meth:`.GroupBy.agg` incorrectly raising in some cases (:issue:`42390`) +- Fixed regression in :meth:`merge` where ``on`` columns with ``ExtensionDtype`` or ``bool`` data types were cast to ``object`` in ``right`` and ``outer`` merge (:issue:`40073`) - Fixed regression in :meth:`RangeIndex.where` and :meth:`RangeIndex.putmask` raising ``AssertionError`` when result did not represent a :class:`RangeIndex` (:issue:`43240`) - Fixed regression in :meth:`read_parquet` where the ``fastparquet`` engine would not work properly with fastparquet 0.7.0 (:issue:`43075`) diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index 62abbb11ee405..bdba1249ffafe 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -70,6 +70,7 @@ Categorical, Index, MultiIndex, + Series, ) from pandas.core import groupby import pandas.core.algorithms as algos @@ -81,10 +82,7 @@ from pandas.core.sorting import is_int64_overflow_possible if TYPE_CHECKING: - from pandas import ( - DataFrame, - Series, - ) + from pandas import DataFrame from pandas.core.arrays import DatetimeArray @@ -904,17 +902,22 @@ def _maybe_add_join_keys( # error: Item "bool" of "Union[Any, bool]" has no attribute "all" if mask_left.all(): # type: ignore[union-attr] key_col = Index(rvals) + result_dtype = rvals.dtype # error: Item "bool" of "Union[Any, bool]" has no attribute "all" elif ( right_indexer is not None and mask_right.all() # type: ignore[union-attr] ): key_col = Index(lvals) + result_dtype = lvals.dtype else: key_col = Index(lvals).where(~mask_left, rvals) + result_dtype = lvals.dtype if result._is_label_reference(name): - result[name] = key_col + result[name] = Series( + key_col, dtype=result_dtype, index=result.index + ) elif result._is_level_reference(name): if isinstance(result.index, MultiIndex): key_col.name = name diff --git a/pandas/tests/reshape/merge/test_merge.py b/pandas/tests/reshape/merge/test_merge.py index 37bbd8553d1b2..71134dcaf9ccc 100644 --- a/pandas/tests/reshape/merge/test_merge.py +++ b/pandas/tests/reshape/merge/test_merge.py @@ -356,8 +356,8 @@ def test_merge_join_key_dtype_cast(self): df = merge(df1, df2, how="outer") # GH13169 - # this really should be bool - assert df["key"].dtype == "object" + # GH#40073 + assert df["key"].dtype == "bool" df1 = DataFrame({"val": [1]}) df2 = DataFrame({"val": [2]}) @@ -368,10 +368,12 @@ def test_merge_join_key_dtype_cast(self): def test_handle_join_key_pass_array(self): left = DataFrame( - {"key": [1, 1, 2, 2, 3], "value": np.arange(5)}, columns=["value", "key"] + {"key": [1, 1, 2, 2, 3], "value": np.arange(5)}, + columns=["value", "key"], + dtype="int64", ) - right = DataFrame({"rvalue": np.arange(6)}) - key = np.array([1, 1, 2, 3, 4, 5]) + right = DataFrame({"rvalue": np.arange(6)}, dtype="int64") + key = np.array([1, 1, 2, 3, 4, 5], dtype="int64") merged = merge(left, right, left_on="key", right_on=key, how="outer") merged2 = merge(right, left, left_on=key, right_on="key", how="outer") @@ -1644,6 +1646,57 @@ def test_merge_incompat_dtypes_error(self, df1_vals, df2_vals): with pytest.raises(ValueError, match=msg): merge(df2, df1, on=["A"]) + @pytest.mark.parametrize( + "expected_data, how", + [ + ([1, 2], "outer"), + ([], "inner"), + ([2], "right"), + ([1], "left"), + ], + ) + def test_merge_EA_dtype(self, any_numeric_ea_dtype, how, expected_data): + # GH#40073 + d1 = DataFrame([(1,)], columns=["id"], dtype=any_numeric_ea_dtype) + d2 = DataFrame([(2,)], columns=["id"], dtype=any_numeric_ea_dtype) + result = merge(d1, d2, how=how) + expected = DataFrame(expected_data, columns=["id"], dtype=any_numeric_ea_dtype) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "expected_data, how", + [ + (["a", "b"], "outer"), + ([], "inner"), + (["b"], "right"), + (["a"], "left"), + ], + ) + def test_merge_string_dtype(self, how, expected_data, any_string_dtype): + # GH#40073 + d1 = DataFrame([("a",)], columns=["id"], dtype=any_string_dtype) + d2 = DataFrame([("b",)], columns=["id"], dtype=any_string_dtype) + result = merge(d1, d2, how=how) + expected = DataFrame(expected_data, columns=["id"], dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "how, expected_data", + [ + ("inner", [[True, 1, 4], [False, 5, 3]]), + ("outer", [[True, 1, 4], [False, 5, 3]]), + ("left", [[True, 1, 4], [False, 5, 3]]), + ("right", [[False, 5, 3], [True, 1, 4]]), + ], + ) + def test_merge_bool_dtype(self, how, expected_data): + # GH#40073 + df1 = DataFrame({"A": [True, False], "B": [1, 5]}) + df2 = DataFrame({"A": [False, True], "C": [3, 4]}) + result = merge(df1, df2, how=how) + expected = DataFrame(expected_data, columns=["A", "B", "C"]) + tm.assert_frame_equal(result, expected) + @pytest.fixture def left():