diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index 42170aaa09978..973c51307ed9c 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -705,6 +705,7 @@ MultiIndex - Bug in :meth:`MultiIndex.union` not sorting when sort=None and index contains missing values (:issue:`49010`) - Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`) - Bug in :meth:`MultiIndex.symmetric_difference` losing extension array (:issue:`48607`) +- Bug in :meth:`MultiIndex.join` losing dtypes when :class:`MultiIndex` has duplicates (:issue:`49830`) - Bug in :meth:`MultiIndex.putmask` losing extension array (:issue:`49830`) - Bug in :meth:`MultiIndex.value_counts` returning a :class:`Series` indexed by flat index of tuples instead of a :class:`MultiIndex` (:issue:`49558`) - diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 64bd66bb5fe0d..29cda0fc72d0b 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -4572,22 +4572,9 @@ def _join_non_unique( ) mask = left_idx == -1 - join_array = self._values.take(left_idx) - right = other._values.take(right_idx) - - if isinstance(join_array, np.ndarray): - # error: Argument 3 to "putmask" has incompatible type - # "Union[ExtensionArray, ndarray[Any, Any]]"; expected - # "Union[_SupportsArray[dtype[Any]], _NestedSequence[ - # _SupportsArray[dtype[Any]]], bool, int, float, complex, - # str, bytes, _NestedSequence[Union[bool, int, float, - # complex, str, bytes]]]" - np.putmask(join_array, mask, right) # type: ignore[arg-type] - else: - join_array._putmask(mask, right) - - join_index = self._wrap_joined_index(join_array, other) - + join_idx = self.take(left_idx) + right = other.take(right_idx) + join_index = join_idx.putmask(mask, right) return join_index, left_idx, right_idx @final @@ -4749,8 +4736,8 @@ def _join_monotonic( ret_index = other if how == "right" else self return ret_index, None, None - ridx: np.ndarray | None - lidx: np.ndarray | None + ridx: npt.NDArray[np.intp] | None + lidx: npt.NDArray[np.intp] | None if self.is_unique and other.is_unique: # We can perform much better than the general case @@ -4764,10 +4751,10 @@ def _join_monotonic( ridx = None elif how == "inner": join_array, lidx, ridx = self._inner_indexer(other) - join_index = self._wrap_joined_index(join_array, other) + join_index = self._wrap_joined_index(join_array, other, lidx, ridx) elif how == "outer": join_array, lidx, ridx = self._outer_indexer(other) - join_index = self._wrap_joined_index(join_array, other) + join_index = self._wrap_joined_index(join_array, other, lidx, ridx) else: if how == "left": join_array, lidx, ridx = self._left_indexer(other) @@ -4778,20 +4765,33 @@ def _join_monotonic( elif how == "outer": join_array, lidx, ridx = self._outer_indexer(other) - join_index = self._wrap_joined_index(join_array, other) + assert lidx is not None + assert ridx is not None + + join_index = self._wrap_joined_index(join_array, other, lidx, ridx) lidx = None if lidx is None else ensure_platform_int(lidx) ridx = None if ridx is None else ensure_platform_int(ridx) return join_index, lidx, ridx - def _wrap_joined_index(self: _IndexT, joined: ArrayLike, other: _IndexT) -> _IndexT: + def _wrap_joined_index( + self: _IndexT, + joined: ArrayLike, + other: _IndexT, + lidx: npt.NDArray[np.intp], + ridx: npt.NDArray[np.intp], + ) -> _IndexT: assert other.dtype == self.dtype if isinstance(self, ABCMultiIndex): name = self.names if self.names == other.names else None # error: Incompatible return value type (got "MultiIndex", # expected "_IndexT") - return self._constructor(joined, name=name) # type: ignore[return-value] + mask = lidx == -1 + join_idx = self.take(lidx) + right = other.take(ridx) + join_index = join_idx.putmask(mask, right) + return join_index.set_names(name) # type: ignore[return-value] else: name = get_op_result_name(self, other) return self._constructor._with_infer(joined, name=name, dtype=self.dtype) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index a6c396555e9a7..afe84befb6b31 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -625,9 +625,11 @@ def _get_join_freq(self, other): freq = self.freq return freq - def _wrap_joined_index(self, joined, other): + def _wrap_joined_index( + self, joined, other, lidx: npt.NDArray[np.intp], ridx: npt.NDArray[np.intp] + ): assert other.dtype == self.dtype, (other.dtype, self.dtype) - result = super()._wrap_joined_index(joined, other) + result = super()._wrap_joined_index(joined, other, lidx, ridx) result._data._freq = self._get_join_freq(other) return result diff --git a/pandas/tests/frame/methods/test_combine_first.py b/pandas/tests/frame/methods/test_combine_first.py index e838c8fabf456..ad6122501bc19 100644 --- a/pandas/tests/frame/methods/test_combine_first.py +++ b/pandas/tests/frame/methods/test_combine_first.py @@ -543,3 +543,17 @@ def test_combine_first_int64_not_cast_to_float64(): result = df_1.combine_first(df_2) expected = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [12, 34, 65]}) tm.assert_frame_equal(result, expected) + + +def test_midx_losing_dtype(): + # GH#49830 + midx = MultiIndex.from_arrays([[0, 0], [np.nan, np.nan]]) + midx2 = MultiIndex.from_arrays([[1, 1], [np.nan, np.nan]]) + df1 = DataFrame({"a": [None, 4]}, index=midx) + df2 = DataFrame({"a": [3, 3]}, index=midx2) + result = df1.combine_first(df2) + expected_midx = MultiIndex.from_arrays( + [[0, 0, 1, 1], [np.nan, np.nan, np.nan, np.nan]] + ) + expected = DataFrame({"a": [np.nan, 4, 3, 3]}, index=expected_midx) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/indexes/multi/test_join.py b/pandas/tests/indexes/multi/test_join.py index aa2f2ca5af7bd..4fff862961920 100644 --- a/pandas/tests/indexes/multi/test_join.py +++ b/pandas/tests/indexes/multi/test_join.py @@ -225,3 +225,35 @@ def test_join_multi_with_nan(): index=MultiIndex.from_product([["A"], [1.0, 2.0]], names=["id1", "id2"]), ) tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("val", [0, 5]) +def test_join_dtypes(any_numeric_ea_dtype, val): + # GH#49830 + midx = MultiIndex.from_arrays([Series([1, 2], dtype=any_numeric_ea_dtype), [3, 4]]) + midx2 = MultiIndex.from_arrays( + [Series([1, val, val], dtype=any_numeric_ea_dtype), [3, 4, 4]] + ) + result = midx.join(midx2, how="outer") + expected = MultiIndex.from_arrays( + [Series([val, val, 1, 2], dtype=any_numeric_ea_dtype), [4, 4, 3, 4]] + ).sort_values() + tm.assert_index_equal(result, expected) + + +def test_join_dtypes_all_nan(any_numeric_ea_dtype): + # GH#49830 + midx = MultiIndex.from_arrays( + [Series([1, 2], dtype=any_numeric_ea_dtype), [np.nan, np.nan]] + ) + midx2 = MultiIndex.from_arrays( + [Series([1, 0, 0], dtype=any_numeric_ea_dtype), [np.nan, np.nan, np.nan]] + ) + result = midx.join(midx2, how="outer") + expected = MultiIndex.from_arrays( + [ + Series([0, 0, 1, 2], dtype=any_numeric_ea_dtype), + [np.nan, np.nan, np.nan, np.nan], + ] + ) + tm.assert_index_equal(result, expected)