Skip to content

Commit 7374a0d

Browse files
authored
BUG: MultiIndex.join losing dtype (#49877)
* BUG: MultiIndex.putmask losing ea dtype
* Fix typing
* Add asv
* Simplify and add whatsnew
* BUG: MultiIndex.join losing dtype
* BUG: MultiIndex.join losing dtype
* Add test
* Add types
1 parent c872a8b commit 7374a0d

File tree

5 files changed

+74
-25
lines changed

5 files changed

+74
-25
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -708,6 +708,7 @@ MultiIndex
708708
- Bug in :meth:`MultiIndex.union` not sorting when sort=None and index contains missing values (:issue:`49010`)
709709
- Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`)
710710
- Bug in :meth:`MultiIndex.symmetric_difference` losing extension array (:issue:`48607`)
711+
- Bug in :meth:`MultiIndex.join` losing dtypes when :class:`MultiIndex` has duplicates (:issue:`49830`)
711712
- Bug in :meth:`MultiIndex.putmask` losing extension array (:issue:`49830`)
712713
- Bug in :meth:`MultiIndex.value_counts` returning a :class:`Series` indexed by flat index of tuples instead of a :class:`MultiIndex` (:issue:`49558`)
713714
-

pandas/core/indexes/base.py

+23-23
Original file line number | Diff line number | Diff line change
@@ -4567,22 +4567,9 @@ def _join_non_unique(
45674567
)
45684568
mask = left_idx == -1
45694569

4570-
join_array = self._values.take(left_idx)
4571-
right = other._values.take(right_idx)
4572-
4573-
if isinstance(join_array, np.ndarray):
4574-
# error: Argument 3 to "putmask" has incompatible type
4575-
# "Union[ExtensionArray, ndarray[Any, Any]]"; expected
4576-
# "Union[_SupportsArray[dtype[Any]], _NestedSequence[
4577-
# _SupportsArray[dtype[Any]]], bool, int, float, complex,
4578-
# str, bytes, _NestedSequence[Union[bool, int, float,
4579-
# complex, str, bytes]]]"
4580-
np.putmask(join_array, mask, right) # type: ignore[arg-type]
4581-
else:
4582-
join_array._putmask(mask, right)
4583-
4584-
join_index = self._wrap_joined_index(join_array, other)
4585-
4570+
join_idx = self.take(left_idx)
4571+
right = other.take(right_idx)
4572+
join_index = join_idx.putmask(mask, right)
45864573
return join_index, left_idx, right_idx
45874574

45884575
@final
@@ -4744,8 +4731,8 @@ def _join_monotonic(
47444731
ret_index = other if how == "right" else self
47454732
return ret_index, None, None
47464733

4747-
ridx: np.ndarray | None
4748-
lidx: np.ndarray | None
4734+
ridx: npt.NDArray[np.intp] | None
4735+
lidx: npt.NDArray[np.intp] | None
47494736

47504737
if self.is_unique and other.is_unique:
47514738
# We can perform much better than the general case
@@ -4759,10 +4746,10 @@ def _join_monotonic(
47594746
ridx = None
47604747
elif how == "inner":
47614748
join_array, lidx, ridx = self._inner_indexer(other)
4762-
join_index = self._wrap_joined_index(join_array, other)
4749+
join_index = self._wrap_joined_index(join_array, other, lidx, ridx)
47634750
elif how == "outer":
47644751
join_array, lidx, ridx = self._outer_indexer(other)
4765-
join_index = self._wrap_joined_index(join_array, other)
4752+
join_index = self._wrap_joined_index(join_array, other, lidx, ridx)
47664753
else:
47674754
if how == "left":
47684755
join_array, lidx, ridx = self._left_indexer(other)
@@ -4773,20 +4760,33 @@ def _join_monotonic(
47734760
elif how == "outer":
47744761
join_array, lidx, ridx = self._outer_indexer(other)
47754762

4776-
join_index = self._wrap_joined_index(join_array, other)
4763+
assert lidx is not None
4764+
assert ridx is not None
4765+
4766+
join_index = self._wrap_joined_index(join_array, other, lidx, ridx)
47774767

47784768
lidx = None if lidx is None else ensure_platform_int(lidx)
47794769
ridx = None if ridx is None else ensure_platform_int(ridx)
47804770
return join_index, lidx, ridx
47814771

4782-
def _wrap_joined_index(self: _IndexT, joined: ArrayLike, other: _IndexT) -> _IndexT:
4772+
def _wrap_joined_index(
4773+
self: _IndexT,
4774+
joined: ArrayLike,
4775+
other: _IndexT,
4776+
lidx: npt.NDArray[np.intp],
4777+
ridx: npt.NDArray[np.intp],
4778+
) -> _IndexT:
47834779
assert other.dtype == self.dtype
47844780

47854781
if isinstance(self, ABCMultiIndex):
47864782
name = self.names if self.names == other.names else None
47874783
# error: Incompatible return value type (got "MultiIndex",
47884784
# expected "_IndexT")
4789-
return self._constructor(joined, name=name) # type: ignore[return-value]
4785+
mask = lidx == -1
4786+
join_idx = self.take(lidx)
4787+
right = other.take(ridx)
4788+
join_index = join_idx.putmask(mask, right)
4789+
return join_index.set_names(name) # type: ignore[return-value]
47904790
else:
47914791
name = get_op_result_name(self, other)
47924792
return self._constructor._with_infer(joined, name=name, dtype=self.dtype)

pandas/core/indexes/datetimelike.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -625,9 +625,11 @@ def _get_join_freq(self, other):
625625
freq = self.freq
626626
return freq
627627

628-
def _wrap_joined_index(self, joined, other):
628+
def _wrap_joined_index(
629+
self, joined, other, lidx: npt.NDArray[np.intp], ridx: npt.NDArray[np.intp]
630+
):
629631
assert other.dtype == self.dtype, (other.dtype, self.dtype)
630-
result = super()._wrap_joined_index(joined, other)
632+
result = super()._wrap_joined_index(joined, other, lidx, ridx)
631633
result._data._freq = self._get_join_freq(other)
632634
return result
633635

pandas/tests/frame/methods/test_combine_first.py

+14
Original file line number | Diff line number | Diff line change
@@ -543,3 +543,17 @@ def test_combine_first_int64_not_cast_to_float64():
543543
result = df_1.combine_first(df_2)
544544
expected = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [12, 34, 65]})
545545
tm.assert_frame_equal(result, expected)
546+
547+
548+
def test_midx_losing_dtype():
549+
# GH#49830
550+
midx = MultiIndex.from_arrays([[0, 0], [np.nan, np.nan]])
551+
midx2 = MultiIndex.from_arrays([[1, 1], [np.nan, np.nan]])
552+
df1 = DataFrame({"a": [None, 4]}, index=midx)
553+
df2 = DataFrame({"a": [3, 3]}, index=midx2)
554+
result = df1.combine_first(df2)
555+
expected_midx = MultiIndex.from_arrays(
556+
[[0, 0, 1, 1], [np.nan, np.nan, np.nan, np.nan]]
557+
)
558+
expected = DataFrame({"a": [np.nan, 4, 3, 3]}, index=expected_midx)
559+
tm.assert_frame_equal(result, expected)

pandas/tests/indexes/multi/test_join.py

+32
Original file line number | Diff line number | Diff line change
@@ -225,3 +225,35 @@ def test_join_multi_with_nan():
225225
index=MultiIndex.from_product([["A"], [1.0, 2.0]], names=["id1", "id2"]),
226226
)
227227
tm.assert_frame_equal(result, expected)
228+
229+
230+
@pytest.mark.parametrize("val", [0, 5])
231+
def test_join_dtypes(any_numeric_ea_dtype, val):
232+
# GH#49830
233+
midx = MultiIndex.from_arrays([Series([1, 2], dtype=any_numeric_ea_dtype), [3, 4]])
234+
midx2 = MultiIndex.from_arrays(
235+
[Series([1, val, val], dtype=any_numeric_ea_dtype), [3, 4, 4]]
236+
)
237+
result = midx.join(midx2, how="outer")
238+
expected = MultiIndex.from_arrays(
239+
[Series([val, val, 1, 2], dtype=any_numeric_ea_dtype), [4, 4, 3, 4]]
240+
).sort_values()
241+
tm.assert_index_equal(result, expected)
242+
243+
244+
def test_join_dtypes_all_nan(any_numeric_ea_dtype):
245+
# GH#49830
246+
midx = MultiIndex.from_arrays(
247+
[Series([1, 2], dtype=any_numeric_ea_dtype), [np.nan, np.nan]]
248+
)
249+
midx2 = MultiIndex.from_arrays(
250+
[Series([1, 0, 0], dtype=any_numeric_ea_dtype), [np.nan, np.nan, np.nan]]
251+
)
252+
result = midx.join(midx2, how="outer")
253+
expected = MultiIndex.from_arrays(
254+
[
255+
Series([0, 0, 1, 2], dtype=any_numeric_ea_dtype),
256+
[np.nan, np.nan, np.nan, np.nan],
257+
]
258+
)
259+
tm.assert_index_equal(result, expected)

0 commit comments

Comments
 (0)