Skip to content

Commit 73c44a8

Browse files
committed
Merge pull request #8946 from jreback/cat_equal
API: Allow equality comparisons of Series with a categorical dtype and object dtype (GH8938)
2 parents 4805582 + a4843f0 commit 73c44a8

File tree

5 files changed

+112
-36
lines changed

5 files changed

+112
-36
lines changed

doc/source/categorical.rst

+25-12
Original file line numberDiff line numberDiff line change
@@ -353,13 +353,23 @@ Reordering the ``categories``, changes a future sort.
353353
Comparisons
354354
-----------
355355

356-
Comparing `Categoricals` with other objects is possible in two cases:
356+
Comparing categorical data with other objects is possible in three cases:
357357

358-
* comparing a categorical Series to another categorical Series, when `categories` and `ordered` is
359-
the same or
360-
* comparing a categorical Series to a scalar.
358+
* comparing equality (``==`` and ``!=``) to a list-like object (list, Series, array,
359+
...) of the same length as the categorical data or
360+
* all comparisons (``==``, ``!=``, ``>``, ``>=``, ``<``, and ``<=``) of categorical data to
361+
another categorical Series, when ``ordered==True`` and the `categories` are the same or
362+
* all comparisons of categorical data to a scalar.
361363

362-
All other comparisons will raise a TypeError.
364+
All other comparisons, especially "non-equality" comparisons of two categoricals with different
365+
categories or a categorical with any list-like object, will raise a TypeError.
366+
367+
.. note::
368+
369+
Any "non-equality" comparisons of categorical data with a `Series`, `np.array`, `list` or
370+
categorical data with different categories or ordering will raise a `TypeError` because custom
371+
categories ordering could be interpreted in two ways: one taking into account the
372+
ordering and one without.
363373

364374
.. ipython:: python
365375
@@ -378,6 +388,13 @@ Comparing to a categorical with the same categories and ordering or to a scalar
378388
cat > cat_base
379389
cat > 2
380390
391+
Equality comparisons work with any list-like object of the same length and with scalars:
392+
393+
.. ipython:: python
394+
395+
cat == cat_base2
396+
cat == 2
397+
381398
This doesn't work because the categories are not the same:
382399

383400
.. ipython:: python
@@ -387,13 +404,9 @@ This doesn't work because the categories are not the same:
387404
except TypeError as e:
388405
print("TypeError: " + str(e))
389406
390-
.. note::
391-
392-
Comparisons with `Series`, `np.array` or a `Categorical` with different categories or ordering
393-
will raise an `TypeError` because custom categories ordering could be interpreted in two ways:
394-
one with taking in account the ordering and one without. If you want to compare a categorical
395-
series with such a type, you need to be explicit and convert the categorical data back to the
396-
original values:
407+
If you want to do a "non-equality" comparison of a categorical series with a list-like object
408+
which is not categorical data, you need to be explicit and convert the categorical data back to
409+
the original values:
397410

398411
.. ipython:: python
399412

doc/source/whatsnew/v0.15.2.txt

+2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ API changes
5959
p = pd.Panel(np.random.rand(2, 5, 4) > 0.1)
6060
p.all()
6161

62+
- Allow equality comparisons of Series with a categorical dtype and object dtype; previously these would raise ``TypeError`` (:issue:`8938`)
63+
6264
.. _whatsnew_0152.enhancements:
6365

6466
Enhancements

pandas/core/categorical.py

+6
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ def f(self, other):
6464
else:
6565
return np.repeat(False, len(self))
6666
else:
67+
68+
# allow categorical vs object dtype array comparisons for equality
69+
# these are only positional comparisons
70+
if op in ['__eq__','__ne__']:
71+
return getattr(np.array(self),op)(np.array(other))
72+
6773
msg = "Cannot compare a Categorical for op {op} with type {typ}. If you want to \n" \
6874
"compare values, use 'np.asarray(cat) <op> other'."
6975
raise TypeError(msg.format(op=op,typ=type(other)))

pandas/core/ops.py

+26-23
Original file line numberDiff line numberDiff line change
@@ -541,10 +541,13 @@ def _comp_method_SERIES(op, name, str_rep, masker=False):
541541
"""
542542
def na_op(x, y):
543543

544-
if com.is_categorical_dtype(x) != (not np.isscalar(y) and com.is_categorical_dtype(y)):
545-
msg = "Cannot compare a Categorical for op {op} with type {typ}. If you want to \n" \
546-
"compare values, use 'series <op> np.asarray(cat)'."
547-
raise TypeError(msg.format(op=op,typ=type(y)))
544+
# dispatch to the categorical if we have a categorical
545+
# in either operand
546+
if com.is_categorical_dtype(x):
547+
return op(x,y)
548+
elif com.is_categorical_dtype(y) and not lib.isscalar(y):
549+
return op(y,x)
550+
548551
if x.dtype == np.object_:
549552
if isinstance(y, list):
550553
y = lib.list_to_object_array(y)
@@ -586,33 +589,33 @@ def wrapper(self, other):
586589
msg = "Cannot compare a Categorical for op {op} with Series of dtype {typ}.\n"\
587590
"If you want to compare values, use 'series <op> np.asarray(other)'."
588591
raise TypeError(msg.format(op=op,typ=self.dtype))
589-
else:
590592

591-
mask = isnull(self)
592593

593-
values = self.get_values()
594-
other = _index.convert_scalar(values,_values_from_object(other))
594+
mask = isnull(self)
595595

596-
if issubclass(values.dtype.type, (np.datetime64, np.timedelta64)):
597-
values = values.view('i8')
596+
values = self.get_values()
597+
other = _index.convert_scalar(values,_values_from_object(other))
598598

599-
# scalars
600-
res = na_op(values, other)
601-
if np.isscalar(res):
602-
raise TypeError('Could not compare %s type with Series'
603-
% type(other))
599+
if issubclass(values.dtype.type, (np.datetime64, np.timedelta64)):
600+
values = values.view('i8')
604601

605-
# always return a full value series here
606-
res = _values_from_object(res)
602+
# scalars
603+
res = na_op(values, other)
604+
if np.isscalar(res):
605+
raise TypeError('Could not compare %s type with Series'
606+
% type(other))
607607

608-
res = pd.Series(res, index=self.index, name=self.name,
609-
dtype='bool')
608+
# always return a full value series here
609+
res = _values_from_object(res)
610610

611-
# mask out the invalids
612-
if mask.any():
613-
res[mask] = masker
611+
res = pd.Series(res, index=self.index, name=self.name,
612+
dtype='bool')
613+
614+
# mask out the invalids
615+
if mask.any():
616+
res[mask] = masker
614617

615-
return res
618+
return res
616619
return wrapper
617620

618621

pandas/tests/test_categorical.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -2212,11 +2212,63 @@ def f():
22122212
tm.assert_series_equal(res, exp)
22132213

22142214
# And test NaN handling...
2215-
cat = pd.Series(pd.Categorical(["a","b","c", np.nan]))
2215+
cat = Series(Categorical(["a","b","c", np.nan]))
22162216
exp = Series([True, True, True, False])
22172217
res = (cat == cat)
22182218
tm.assert_series_equal(res, exp)
22192219

2220+
def test_cat_equality(self):
2221+
2222+
# GH 8938
2223+
# allow equality comparisons
2224+
a = Series(list('abc'),dtype="category")
2225+
b = Series(list('abc'),dtype="object")
2226+
c = Series(['a','b','cc'],dtype="object")
2227+
d = Series(list('acb'),dtype="object")
2228+
e = Categorical(list('abc'))
2229+
f = Categorical(list('acb'))
2230+
2231+
# vs scalar
2232+
self.assertFalse((a=='a').all())
2233+
self.assertTrue(((a!='a') == ~(a=='a')).all())
2234+
2235+
self.assertFalse(('a'==a).all())
2236+
self.assertTrue((a=='a')[0])
2237+
self.assertTrue(('a'==a)[0])
2238+
self.assertFalse(('a'!=a)[0])
2239+
2240+
# vs list-like
2241+
self.assertTrue((a==a).all())
2242+
self.assertFalse((a!=a).all())
2243+
2244+
self.assertTrue((a==list(a)).all())
2245+
self.assertTrue((a==b).all())
2246+
self.assertTrue((b==a).all())
2247+
self.assertTrue(((~(a==b))==(a!=b)).all())
2248+
self.assertTrue(((~(b==a))==(b!=a)).all())
2249+
2250+
self.assertFalse((a==c).all())
2251+
self.assertFalse((c==a).all())
2252+
self.assertFalse((a==d).all())
2253+
self.assertFalse((d==a).all())
2254+
2255+
# vs a cat-like
2256+
self.assertTrue((a==e).all())
2257+
self.assertTrue((e==a).all())
2258+
self.assertFalse((a==f).all())
2259+
self.assertFalse((f==a).all())
2260+
2261+
self.assertTrue(((~(a==e)==(a!=e)).all()))
2262+
self.assertTrue(((~(e==a)==(e!=a)).all()))
2263+
self.assertTrue(((~(a==f)==(a!=f)).all()))
2264+
self.assertTrue(((~(f==a)==(f!=a)).all()))
2265+
2266+
# non-equality is not comparable
2267+
self.assertRaises(TypeError, lambda: a < b)
2268+
self.assertRaises(TypeError, lambda: b < a)
2269+
self.assertRaises(TypeError, lambda: a > b)
2270+
self.assertRaises(TypeError, lambda: b > a)
2271+
22202272
def test_concat(self):
22212273
cat = pd.Categorical(["a","b"], categories=["a","b"])
22222274
vals = [1,2]

0 commit comments

Comments
 (0)