Skip to content

Commit 3a6d742

Browse files
jbrockmendelKevin D Smith
authored and
Kevin D Smith
committed
BUG: MultiIndex comparison with tuple pandas-dev#21517 (pandas-dev#37170)
1 parent cff1146 commit 3a6d742

File tree

3 files changed

+50
-7
lines changed

3 files changed

+50
-7
lines changed

doc/source/whatsnew/v1.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ Numeric
394394
- Bug in :meth:`DataFrame.__rmatmul__` error handling reporting transposed shapes (:issue:`21581`)
395395
- Bug in :class:`Series` flex arithmetic methods where the result when operating with a ``list``, ``tuple`` or ``np.ndarray`` would have an incorrect name (:issue:`36760`)
396396
- Bug in :class:`IntegerArray` multiplication with ``timedelta`` and ``np.timedelta64`` objects (:issue:`36870`)
397+
- Bug in :class:`MultiIndex` comparison with tuple incorrectly treating tuple as array-like (:issue:`21517`)
397398
- Bug in :meth:`DataFrame.diff` with ``datetime64`` dtypes including ``NaT`` values failing to fill ``NaT`` results correctly (:issue:`32441`)
398399
- Bug in :class:`DataFrame` arithmetic ops incorrectly accepting keyword arguments (:issue:`36843`)
399400
- Bug in :class:`IntervalArray` comparisons with :class:`Series` not returning :class:`Series` (:issue:`36908`)

pandas/core/indexes/base.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@
6565
)
6666
from pandas.core.dtypes.concat import concat_compat
6767
from pandas.core.dtypes.generic import (
68-
ABCCategorical,
6968
ABCDatetimeIndex,
7069
ABCMultiIndex,
7170
ABCPandasArray,
@@ -83,6 +82,7 @@
8382
from pandas.core.arrays.datetimes import tz_to_dtype, validate_tz_from_dtype
8483
from pandas.core.base import IndexOpsMixin, PandasObject
8584
import pandas.core.common as com
85+
from pandas.core.construction import extract_array
8686
from pandas.core.indexers import deprecate_ndim_indexing
8787
from pandas.core.indexes.frozen import FrozenList
8888
from pandas.core.ops import get_op_result_name
@@ -5376,11 +5376,13 @@ def _cmp_method(self, other, op):
53765376
if len(self) != len(other):
53775377
raise ValueError("Lengths must match to compare")
53785378

5379-
if is_object_dtype(self.dtype) and isinstance(other, ABCCategorical):
5380-
left = type(other)(self._values, dtype=other.dtype)
5381-
return op(left, other)
5382-
elif is_object_dtype(self.dtype) and isinstance(other, ExtensionArray):
5383-
# e.g. PeriodArray
5379+
if not isinstance(other, ABCMultiIndex):
5380+
other = extract_array(other, extract_numpy=True)
5381+
else:
5382+
other = np.asarray(other)
5383+
5384+
if is_object_dtype(self.dtype) and isinstance(other, ExtensionArray):
5385+
# e.g. PeriodArray, Categorical
53845386
with np.errstate(all="ignore"):
53855387
result = op(self._values, other)
53865388

@@ -5395,7 +5397,7 @@ def _cmp_method(self, other, op):
53955397

53965398
else:
53975399
with np.errstate(all="ignore"):
5398-
result = ops.comparison_op(self._values, np.asarray(other), op)
5400+
result = ops.comparison_op(self._values, other, op)
53995401

54005402
return result
54015403

pandas/tests/indexes/multi/test_equivalence.py

+40
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,46 @@ def test_equals_op(idx):
8484
tm.assert_series_equal(series_a == item, Series(expected3))
8585

8686

87+
def test_compare_tuple():
88+
# GH#21517
89+
mi = MultiIndex.from_product([[1, 2]] * 2)
90+
91+
all_false = np.array([False, False, False, False])
92+
93+
result = mi == mi[0]
94+
expected = np.array([True, False, False, False])
95+
tm.assert_numpy_array_equal(result, expected)
96+
97+
result = mi != mi[0]
98+
tm.assert_numpy_array_equal(result, ~expected)
99+
100+
result = mi < mi[0]
101+
tm.assert_numpy_array_equal(result, all_false)
102+
103+
result = mi <= mi[0]
104+
tm.assert_numpy_array_equal(result, expected)
105+
106+
result = mi > mi[0]
107+
tm.assert_numpy_array_equal(result, ~expected)
108+
109+
result = mi >= mi[0]
110+
tm.assert_numpy_array_equal(result, ~all_false)
111+
112+
113+
def test_compare_tuple_strs():
114+
# GH#34180
115+
116+
mi = MultiIndex.from_tuples([("a", "b"), ("b", "c"), ("c", "a")])
117+
118+
result = mi == ("c", "a")
119+
expected = np.array([False, False, True])
120+
tm.assert_numpy_array_equal(result, expected)
121+
122+
result = mi == ("c",)
123+
expected = np.array([False, False, False])
124+
tm.assert_numpy_array_equal(result, expected)
125+
126+
87127
def test_equals_multi(idx):
88128
assert idx.equals(idx)
89129
assert not idx.equals(idx.values)

0 commit comments

Comments
 (0)