Skip to content

Commit f98e744

Browse files
committed
API: Allow equality comparisons of Series with a categorical dtype and object dtype are allowed (previously would raise TypeError) (GH8938)
1 parent 8290a4d commit f98e744

File tree

4 files changed

+87
-24
lines changed

4 files changed

+87
-24
lines changed

doc/source/whatsnew/v0.15.2.txt

+2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ API changes
5959
p = pd.Panel(np.random.rand(2, 5, 4) > 0.1)
6060
p.all()
6161

62+
- Allow equality comparisons of Series with a categorical dtype and object dtype; previously these would raise ``TypeError`` (:issue:`8938`)
63+
6264
.. _whatsnew_0152.enhancements:
6365

6466
Enhancements

pandas/core/categorical.py

+6
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ def f(self, other):
6464
else:
6565
return np.repeat(False, len(self))
6666
else:
67+
68+
# allow categorical vs object dtype array comparisons for equality
69+
# these are only positional comparisons
70+
if op in ['__eq__','__ne__']:
71+
return getattr(np.array(self),op)(np.array(other))
72+
6773
msg = "Cannot compare a Categorical for op {op} with type {typ}. If you want to \n" \
6874
"compare values, use 'np.asarray(cat) <op> other'."
6975
raise TypeError(msg.format(op=op,typ=type(other)))

pandas/core/ops.py

+26-23
Original file line numberDiff line numberDiff line change
@@ -541,10 +541,13 @@ def _comp_method_SERIES(op, name, str_rep, masker=False):
541541
"""
542542
def na_op(x, y):
543543

544-
if com.is_categorical_dtype(x) != (not np.isscalar(y) and com.is_categorical_dtype(y)):
545-
msg = "Cannot compare a Categorical for op {op} with type {typ}. If you want to \n" \
546-
"compare values, use 'series <op> np.asarray(cat)'."
547-
raise TypeError(msg.format(op=op,typ=type(y)))
544+
# dispatch to the categorical if we have a categorical
545+
# in either operand
546+
if com.is_categorical_dtype(x):
547+
return op(x,y)
548+
elif com.is_categorical_dtype(y) and not lib.isscalar(y):
549+
return op(y,x)
550+
548551
if x.dtype == np.object_:
549552
if isinstance(y, list):
550553
y = lib.list_to_object_array(y)
@@ -586,33 +589,33 @@ def wrapper(self, other):
586589
msg = "Cannot compare a Categorical for op {op} with Series of dtype {typ}.\n"\
587590
"If you want to compare values, use 'series <op> np.asarray(other)'."
588591
raise TypeError(msg.format(op=op,typ=self.dtype))
589-
else:
590592

591-
mask = isnull(self)
592593

593-
values = self.get_values()
594-
other = _index.convert_scalar(values,_values_from_object(other))
594+
mask = isnull(self)
595595

596-
if issubclass(values.dtype.type, (np.datetime64, np.timedelta64)):
597-
values = values.view('i8')
596+
values = self.get_values()
597+
other = _index.convert_scalar(values,_values_from_object(other))
598598

599-
# scalars
600-
res = na_op(values, other)
601-
if np.isscalar(res):
602-
raise TypeError('Could not compare %s type with Series'
603-
% type(other))
599+
if issubclass(values.dtype.type, (np.datetime64, np.timedelta64)):
600+
values = values.view('i8')
604601

605-
# always return a full value series here
606-
res = _values_from_object(res)
602+
# scalars
603+
res = na_op(values, other)
604+
if np.isscalar(res):
605+
raise TypeError('Could not compare %s type with Series'
606+
% type(other))
607607

608-
res = pd.Series(res, index=self.index, name=self.name,
609-
dtype='bool')
608+
# always return a full value series here
609+
res = _values_from_object(res)
610610

611-
# mask out the invalids
612-
if mask.any():
613-
res[mask] = masker
611+
res = pd.Series(res, index=self.index, name=self.name,
612+
dtype='bool')
613+
614+
# mask out the invalids
615+
if mask.any():
616+
res[mask] = masker
614617

615-
return res
618+
return res
616619
return wrapper
617620

618621

pandas/tests/test_categorical.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -2211,11 +2211,63 @@ def f():
22112211
tm.assert_series_equal(res, exp)
22122212

22132213
# And test NaN handling...
2214-
cat = pd.Series(pd.Categorical(["a","b","c", np.nan]))
2214+
cat = Series(Categorical(["a","b","c", np.nan]))
22152215
exp = Series([True, True, True, False])
22162216
res = (cat == cat)
22172217
tm.assert_series_equal(res, exp)
22182218

2219+
def test_cat_equality(self):
2220+
2221+
# GH 8938
2222+
# allow equality comparisons
2223+
a = Series(list('abc'),dtype="category")
2224+
b = Series(list('abc'),dtype="object")
2225+
c = Series(['a','b','cc'],dtype="object")
2226+
d = Series(list('acb'),dtype="object")
2227+
e = Categorical(list('abc'))
2228+
f = Categorical(list('acb'))
2229+
2230+
# vs scalar
2231+
self.assertFalse((a=='a').all())
2232+
self.assertTrue(((a!='a') == ~(a=='a')).all())
2233+
2234+
self.assertFalse(('a'==a).all())
2235+
self.assertTrue((a=='a')[0])
2236+
self.assertTrue(('a'==a)[0])
2237+
self.assertFalse(('a'!=a)[0])
2238+
2239+
# vs list-like
2240+
self.assertTrue((a==a).all())
2241+
self.assertFalse((a!=a).all())
2242+
2243+
self.assertTrue((a==list(a)).all())
2244+
self.assertTrue((a==b).all())
2245+
self.assertTrue((b==a).all())
2246+
self.assertTrue(((~(a==b))==(a!=b)).all())
2247+
self.assertTrue(((~(b==a))==(b!=a)).all())
2248+
2249+
self.assertFalse((a==c).all())
2250+
self.assertFalse((c==a).all())
2251+
self.assertFalse((a==d).all())
2252+
self.assertFalse((d==a).all())
2253+
2254+
# vs a cat-like
2255+
self.assertTrue((a==e).all())
2256+
self.assertTrue((e==a).all())
2257+
self.assertFalse((a==f).all())
2258+
self.assertFalse((f==a).all())
2259+
2260+
self.assertTrue(((~(a==e)==(a!=e)).all()))
2261+
self.assertTrue(((~(e==a)==(e!=a)).all()))
2262+
self.assertTrue(((~(a==f)==(a!=f)).all()))
2263+
self.assertTrue(((~(f==a)==(f!=a)).all()))
2264+
2265+
# non-equality is not comparable
2266+
self.assertRaises(TypeError, lambda: a < b)
2267+
self.assertRaises(TypeError, lambda: b < a)
2268+
self.assertRaises(TypeError, lambda: a > b)
2269+
self.assertRaises(TypeError, lambda: b > a)
2270+
22192271
def test_concat(self):
22202272
cat = pd.Categorical(["a","b"], categories=["a","b"])
22212273
vals = [1,2]

0 commit comments

Comments
 (0)