Skip to content

Commit 9c22b53

Browse files
committed
disable some CategoricalIndex comparisons
1 parent 0b7ae90 commit 9c22b53

File tree

2 files changed

+86
-28
lines changed

2 files changed

+86
-28
lines changed

pandas/core/index.py

+63-28
Original file line numberDiff line numberDiff line change
@@ -45,27 +45,6 @@ def _try_get_item(x):
4545
except AttributeError:
4646
return x
4747

48-
def _indexOp(opname):
49-
"""
50-
Wrapper function for index comparison operations, to avoid
51-
code duplication.
52-
"""
53-
54-
def wrapper(self, other):
55-
func = getattr(self._data.view(np.ndarray), opname)
56-
result = func(np.asarray(other))
57-
58-
# technically we could support bool dtyped Index
59-
# for now just return the indexing array directly
60-
if is_bool_dtype(result):
61-
return result
62-
try:
63-
return Index(result)
64-
except: # pragma: no cover
65-
return result
66-
return wrapper
67-
68-
6948
class InvalidIndexError(Exception):
7049
pass
7150

@@ -1216,13 +1195,6 @@ def __sub__(self, other):
12161195
"use .difference()",FutureWarning)
12171196
return self.difference(other)
12181197

1219-
__eq__ = _indexOp('__eq__')
1220-
__ne__ = _indexOp('__ne__')
1221-
__lt__ = _indexOp('__lt__')
1222-
__gt__ = _indexOp('__gt__')
1223-
__le__ = _indexOp('__le__')
1224-
__ge__ = _indexOp('__ge__')
1225-
12261198
def __and__(self, other):
12271199
return self.intersection(other)
12281200

@@ -2380,6 +2352,34 @@ def _evaluate_with_timedelta_like(self, other, op, opstr):
23802352
def _evaluate_with_datetime_like(self, other, op, opstr):
23812353
raise TypeError("can only perform ops with datetime like values")
23822354

2355+
@classmethod
2356+
def _add_comparison_methods(cls):
2357+
""" add in comparison methods """
2358+
2359+
def _make_compare(op):
2360+
2361+
def _evaluate_compare(self, other):
2362+
func = getattr(self._data.view(np.ndarray), op)
2363+
result = func(np.asarray(other))
2364+
2365+
# technically we could support bool dtyped Index
2366+
# for now just return the indexing array directly
2367+
if is_bool_dtype(result):
2368+
return result
2369+
try:
2370+
return Index(result)
2371+
except: # pragma: no cover
2372+
return result
2373+
2374+
return _evaluate_compare
2375+
2376+
cls.__eq__ = _make_compare('__eq__')
2377+
cls.__ne__ = _make_compare('__ne__')
2378+
cls.__lt__ = _make_compare('__lt__')
2379+
cls.__gt__ = _make_compare('__gt__')
2380+
cls.__le__ = _make_compare('__le__')
2381+
cls.__ge__ = _make_compare('__ge__')
2382+
23832383
@classmethod
23842384
def _add_numeric_methods_disabled(cls):
23852385
""" add in numeric methods to disable """
@@ -2530,6 +2530,7 @@ def invalid_op(self, other=None):
25302530

25312531
Index._add_numeric_methods_disabled()
25322532
Index._add_logical_methods()
2533+
Index._add_comparison_methods()
25332534

25342535
class CategoricalIndex(Index):
25352536
"""
@@ -2817,8 +2818,42 @@ def convert(c):
28172818
cat = Categorical.from_codes(codes, categories=categories)
28182819
return CategoricalIndex(cat, name=name)
28192820

2821+
@classmethod
2822+
def _add_comparison_methods(cls):
2823+
""" add in comparison methods """
2824+
2825+
def _make_compare(op):
2826+
2827+
def _evaluate_compare(self, other):
2828+
2829+
# we must have only CategoricalIndexes here
2830+
if not isinstance(other, CategoricalIndex):
2831+
raise TypeError("cannot compare a non-categorical index vs a categorical index")
2832+
if not other.categories.equals(self.categories):
2833+
raise TypeError("categorical index comparisions must have the same categories")
2834+
2835+
return getattr(self.codes, op)(other.codes)
2836+
2837+
return _evaluate_compare
2838+
2839+
def _make_invalid_op(name):
2840+
2841+
def invalid_op(self, other=None):
2842+
raise TypeError("cannot perform {name} with this index type: {typ}".format(name=name,
2843+
typ=type(self)))
2844+
invalid_op.__name__ = name
2845+
return invalid_op
2846+
2847+
cls.__eq__ = _make_compare('__eq__')
2848+
cls.__ne__ = _make_compare('__ne__')
2849+
cls.__lt__ = _make_invalid_op('__lt__')
2850+
cls.__gt__ = _make_invalid_op('__gt__')
2851+
cls.__le__ = _make_invalid_op('__le__')
2852+
cls.__ge__ = _make_invalid_op('__ge__')
2853+
28202854
CategoricalIndex._add_numeric_methods_disabled()
28212855
CategoricalIndex._add_logical_methods_disabled()
2856+
CategoricalIndex._add_comparison_methods()
28222857

28232858
class NumericIndex(Index):
28242859
"""

pandas/tests/test_index.py

+23
Original file line numberDiff line numberDiff line change
@@ -1484,6 +1484,29 @@ def test_get_indexer(self):
14841484
self.assertRaises(NotImplementedError, lambda : idx2.get_indexer(idx1, method='backfill'))
14851485
self.assertRaises(NotImplementedError, lambda : idx2.get_indexer(idx1, method='nearest'))
14861486

1487+
def test_equals(self):
1488+
1489+
ci1 = CategoricalIndex(['a', 'b'], categories=['a', 'b'])
1490+
ci2 = CategoricalIndex(['a', 'b'], categories=['a', 'b', 'c'])
1491+
1492+
self.assertTrue(ci1.equals(ci1))
1493+
self.assertFalse(ci1.equals(ci2))
1494+
self.assertTrue((ci1 == ci1).all())
1495+
self.assertFalse((ci1 != ci1).all())
1496+
1497+
# invalid comparisons
1498+
self.assertRaises(TypeError, lambda : ci1 == 1)
1499+
self.assertRaises(TypeError, lambda : ci1 == Index(['a','b']))
1500+
self.assertRaises(TypeError, lambda : ci1 == Index(['a','b']))
1501+
self.assertRaises(TypeError, lambda : ci1 == ci2)
1502+
1503+
# invalid ops
1504+
self.assertRaises(TypeError, lambda : ci1 < ci1)
1505+
self.assertRaises(TypeError, lambda : ci1 <= ci1)
1506+
self.assertRaises(TypeError, lambda : ci1 > ci1)
1507+
self.assertRaises(TypeError, lambda : ci1 >= ci1)
1508+
1509+
14871510
class Numeric(Base):
14881511

14891512
def test_numeric_compat(self):

0 commit comments

Comments
 (0)