Skip to content

Commit 2242fd9

Browse files
committed
Merge pull request #10637 from mortada/index_compare_tests
BUG: made behavior of operator equal for CategoricalIndex consistent,…
2 parents ebea3a3 + 81d9e0b commit 2242fd9

File tree

3 files changed

+75
-56
lines changed

3 files changed

+75
-56
lines changed

doc/source/whatsnew/v0.17.0.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ in the method call.
156156
Changes to Index Comparisons
157157
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
158158

159-
Operator equal on Index should behavior similarly to Series (:issue:`9947`)
159+
Operator equal on Index should behavior similarly to Series (:issue:`9947`, :issue:`10637`)
160160

161161
Starting in v0.17.0, comparing ``Index`` objects of different lengths will raise
162162
a ``ValueError``. This is to be consistent with the behavior of ``Series``.
@@ -390,7 +390,6 @@ Bug Fixes
390390

391391

392392

393-
- Bug in operator equal on Index not being consistent with Series (:issue:`9947`)
394393
- Reading "famafrench" data via ``DataReader`` results in HTTP 404 error because of the website url is changed (:issue:`10591`).
395394
- Bug in `read_msgpack` where DataFrame to decode has duplicate column names (:issue:`9618`)
396395

pandas/core/index.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -3260,8 +3260,12 @@ def _evaluate_compare(self, other):
32603260
elif isinstance(other, Index):
32613261
other = self._create_categorical(self, other.values, categories=self.categories, ordered=self.ordered)
32623262

3263+
if isinstance(other, (ABCCategorical, np.ndarray, ABCSeries)):
3264+
if len(self.values) != len(other):
3265+
raise ValueError("Lengths must match to compare")
3266+
32633267
if isinstance(other, ABCCategorical):
3264-
if not (self.values.is_dtype_equal(other) and len(self.values) == len(other)):
3268+
if not self.values.is_dtype_equal(other):
32653269
raise TypeError("categorical index comparisions must have the same categories and ordered attributes")
32663270

32673271
return getattr(self.values, op)(other)

pandas/tests/test_index.py

+69-53
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,66 @@ def test_symmetric_diff(self):
396396
with tm.assertRaisesRegexp(TypeError, msg):
397397
result = first.sym_diff([1, 2, 3])
398398

399+
def test_equals_op(self):
400+
# GH9947, GH10637
401+
index_a = self.create_index()
402+
if isinstance(index_a, PeriodIndex):
403+
return
404+
405+
n = len(index_a)
406+
index_b = index_a[0:-1]
407+
index_c = index_a[0:-1].append(index_a[-2:-1])
408+
index_d = index_a[0:1]
409+
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
410+
index_a == index_b
411+
expected1 = np.array([True] * n)
412+
expected2 = np.array([True] * (n - 1) + [False])
413+
assert_numpy_array_equivalent(index_a == index_a, expected1)
414+
assert_numpy_array_equivalent(index_a == index_c, expected2)
415+
416+
# test comparisons with numpy arrays
417+
array_a = np.array(index_a)
418+
array_b = np.array(index_a[0:-1])
419+
array_c = np.array(index_a[0:-1].append(index_a[-2:-1]))
420+
array_d = np.array(index_a[0:1])
421+
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
422+
index_a == array_b
423+
assert_numpy_array_equivalent(index_a == array_a, expected1)
424+
assert_numpy_array_equivalent(index_a == array_c, expected2)
425+
426+
# test comparisons with Series
427+
series_a = Series(array_a)
428+
series_b = Series(array_b)
429+
series_c = Series(array_c)
430+
series_d = Series(array_d)
431+
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
432+
index_a == series_b
433+
assert_numpy_array_equivalent(index_a == series_a, expected1)
434+
assert_numpy_array_equivalent(index_a == series_c, expected2)
435+
436+
# cases where length is 1 for one of them
437+
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
438+
index_a == index_d
439+
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
440+
index_a == series_d
441+
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
442+
index_a == array_d
443+
with tm.assertRaisesRegexp(ValueError, "Series lengths must match"):
444+
series_a == series_d
445+
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
446+
series_a == array_d
447+
448+
# comparing with a scalar should broadcast; note that we are excluding
449+
# MultiIndex because in this case each item in the index is a tuple of
450+
# length 2, and therefore is considered an array of length 2 in the
451+
# comparison instead of a scalar
452+
if not isinstance(index_a, MultiIndex):
453+
expected3 = np.array([False] * (len(index_a) - 2) + [True, False])
454+
# assuming the 2nd to last item is unique in the data
455+
item = index_a[-2]
456+
assert_numpy_array_equivalent(index_a == item, expected3)
457+
assert_numpy_array_equivalent(series_a == item, expected3)
458+
399459

400460
class TestIndex(Base, tm.TestCase):
401461
_holder = Index
@@ -1548,54 +1608,7 @@ def test_groupby(self):
15481608
exp = {1: [0, 1], 2: [2, 3, 4]}
15491609
tm.assert_dict_equal(groups, exp)
15501610

1551-
def test_equals_op(self):
1552-
# GH9947
1553-
index_a = Index(['foo', 'bar', 'baz'])
1554-
index_b = Index(['foo', 'bar', 'baz', 'qux'])
1555-
index_c = Index(['foo', 'bar', 'qux'])
1556-
index_d = Index(['foo'])
1557-
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
1558-
index_a == index_b
1559-
assert_numpy_array_equivalent(index_a == index_a, np.array([True, True, True]))
1560-
assert_numpy_array_equivalent(index_a == index_c, np.array([True, True, False]))
1561-
1562-
# test comparisons with numpy arrays
1563-
array_a = np.array(['foo', 'bar', 'baz'])
1564-
array_b = np.array(['foo', 'bar', 'baz', 'qux'])
1565-
array_c = np.array(['foo', 'bar', 'qux'])
1566-
array_d = np.array(['foo'])
1567-
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
1568-
index_a == array_b
1569-
assert_numpy_array_equivalent(index_a == array_a, np.array([True, True, True]))
1570-
assert_numpy_array_equivalent(index_a == array_c, np.array([True, True, False]))
1571-
1572-
# test comparisons with Series
1573-
series_a = Series(['foo', 'bar', 'baz'])
1574-
series_b = Series(['foo', 'bar', 'baz', 'qux'])
1575-
series_c = Series(['foo', 'bar', 'qux'])
1576-
series_d = Series(['foo'])
1577-
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
1578-
index_a == series_b
1579-
assert_numpy_array_equivalent(index_a == series_a, np.array([True, True, True]))
1580-
assert_numpy_array_equivalent(index_a == series_c, np.array([True, True, False]))
1581-
1582-
# cases where length is 1 for one of them
1583-
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
1584-
index_a == index_d
1585-
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
1586-
index_a == series_d
1587-
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
1588-
index_a == array_d
1589-
with tm.assertRaisesRegexp(ValueError, "Series lengths must match"):
1590-
series_a == series_d
1591-
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
1592-
series_a == array_d
1593-
1594-
# comparing with scalar should broadcast
1595-
assert_numpy_array_equivalent(index_a == 'foo', np.array([True, False, False]))
1596-
assert_numpy_array_equivalent(series_a == 'foo', np.array([True, False, False]))
1597-
assert_numpy_array_equivalent(array_a == 'foo', np.array([True, False, False]))
1598-
1611+
def test_equals_op_multiindex(self):
15991612
# GH9785
16001613
# test comparisons of multiindex
16011614
from pandas.compat import StringIO
@@ -1609,6 +1622,8 @@ def test_equals_op(self):
16091622
mi3 = MultiIndex.from_tuples([(1, 2), (4, 5), (8, 9)])
16101623
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
16111624
df.index == mi3
1625+
1626+
index_a = Index(['foo', 'bar', 'baz'])
16121627
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
16131628
df.index == index_a
16141629
assert_numpy_array_equivalent(index_a == mi3, np.array([False, False, False]))
@@ -1966,7 +1981,8 @@ def test_equals(self):
19661981
self.assertTrue((ci1 == ci1.values).all())
19671982

19681983
# invalid comparisons
1969-
self.assertRaises(TypeError, lambda : ci1 == Index(['a','b','c']))
1984+
with tm.assertRaisesRegexp(ValueError, "Lengths must match"):
1985+
ci1 == Index(['a','b','c'])
19701986
self.assertRaises(TypeError, lambda : ci1 == ci2)
19711987
self.assertRaises(TypeError, lambda : ci1 == Categorical(ci1.values, ordered=False))
19721988
self.assertRaises(TypeError, lambda : ci1 == Categorical(ci1.values, categories=list('abc')))
@@ -2082,7 +2098,7 @@ def setUp(self):
20822098
self.setup_indices()
20832099

20842100
def create_index(self):
2085-
return Float64Index(np.arange(5,dtype='float64'))
2101+
return Float64Index(np.arange(5, dtype='float64'))
20862102

20872103
def test_repr_roundtrip(self):
20882104
for ind in (self.mixed, self.float):
@@ -2253,7 +2269,7 @@ def setUp(self):
22532269
self.setup_indices()
22542270

22552271
def create_index(self):
2256-
return Int64Index(np.arange(5,dtype='int64'))
2272+
return Int64Index(np.arange(5, dtype='int64'))
22572273

22582274
def test_too_many_names(self):
22592275
def testit():
@@ -2743,7 +2759,7 @@ def setUp(self):
27432759
self.setup_indices()
27442760

27452761
def create_index(self):
2746-
return date_range('20130101',periods=5)
2762+
return date_range('20130101', periods=5)
27472763

27482764
def test_pickle_compat_construction(self):
27492765
pass
@@ -2936,7 +2952,7 @@ def setUp(self):
29362952
self.setup_indices()
29372953

29382954
def create_index(self):
2939-
return pd.to_timedelta(range(5),unit='d') + pd.offsets.Hour(1)
2955+
return pd.to_timedelta(range(5), unit='d') + pd.offsets.Hour(1)
29402956

29412957
def test_get_loc(self):
29422958
idx = pd.to_timedelta(['0 days', '1 days', '2 days'])

0 commit comments

Comments
 (0)