Skip to content

Commit f0ac930

Browse files
jankatins authored and jreback committed
Fix: unequal comparisons of categorical and scalar
Previously, unequal comparisons did not check the order of the categories. This was due to a conversion to an ndarray, which turned the comparison into one between an ndarray and a scalar, and an ndarray of course has no categories to take into account. Also add test cases and remove the one that actually tested the wrong behaviour.
1 parent 700f6eb commit f0ac930

File tree

3 files changed

+50
-14
lines changed

3 files changed

+50
-14
lines changed

doc/source/whatsnew/v0.16.1.txt

+1
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,4 @@ Bug Fixes
120120

121121
- Bug in which ``SparseDataFrame`` could not take `nan` as a column name (:issue:`8822`)
122122

123+
- Bug in unequal comparisons between a ``Series`` of dtype `"category"` and a scalar (e.g. ``Series(Categorical(list("abc"), categories=list("cba"), ordered=True)) > "b"``), which did not use the order of the categories but used lexicographical order instead. (:issue:`9848`)

pandas/core/ops.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -594,20 +594,26 @@ def wrapper(self, other):
594594

595595
mask = isnull(self)
596596

597-
values = self.get_values()
598-
other = _index.convert_scalar(values,_values_from_object(other))
597+
if com.is_categorical_dtype(self):
598+
# cats are a special case as get_values() would return an ndarray, which would then
599+
# not take categories ordering into account
600+
# we can go directly to op, as the na_op would just test again and dispatch to it.
601+
res = op(self.values, other)
602+
else:
603+
values = self.get_values()
604+
other = _index.convert_scalar(values,_values_from_object(other))
599605

600-
if issubclass(values.dtype.type, (np.datetime64, np.timedelta64)):
601-
values = values.view('i8')
606+
if issubclass(values.dtype.type, (np.datetime64, np.timedelta64)):
607+
values = values.view('i8')
602608

603-
# scalars
604-
res = na_op(values, other)
605-
if np.isscalar(res):
606-
raise TypeError('Could not compare %s type with Series'
607-
% type(other))
609+
# scalars
610+
res = na_op(values, other)
611+
if np.isscalar(res):
612+
raise TypeError('Could not compare %s type with Series'
613+
% type(other))
608614

609-
# always return a full value series here
610-
res = _values_from_object(res)
615+
# always return a full value series here
616+
res = _values_from_object(res)
611617

612618
res = pd.Series(res, index=self.index, name=self.name,
613619
dtype='bool')

pandas/tests/test_categorical.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ def f():
114114
Categorical([1,2], [1,2,np.nan, np.nan])
115115
self.assertRaises(ValueError, f)
116116

117+
# The default should be unordered
118+
c1 = Categorical(["a", "b", "c", "a"])
119+
self.assertFalse(c1.ordered)
117120

118121
# Categorical as input
119122
c1 = Categorical(["a", "b", "c", "a"])
@@ -367,6 +370,13 @@ def f():
367370
self.assertRaises(TypeError, lambda: a < cat)
368371
self.assertRaises(TypeError, lambda: a < cat_rev)
369372

373+
# Make sure that unequal comparisons take the order of the categories into account
374+
cat_rev = pd.Categorical(list("abc"), categories=list("cba"), ordered=True)
375+
exp = np.array([True, False, False])
376+
res = cat_rev > "b"
377+
self.assert_numpy_array_equal(res, exp)
378+
379+
370380
def test_na_flags_int_categories(self):
371381
# #1457
372382

@@ -2390,6 +2400,18 @@ def test_comparisons(self):
23902400
exp = Series([False, False, True])
23912401
tm.assert_series_equal(res, exp)
23922402

2403+
scalar = base[1]
2404+
res = cat > scalar
2405+
exp = Series([False, False, True])
2406+
exp2 = cat.values > scalar
2407+
tm.assert_series_equal(res, exp)
2408+
tm.assert_numpy_array_equal(res.values, exp2)
2409+
res_rev = cat_rev > scalar
2410+
exp_rev = Series([True, False, False])
2411+
exp_rev2 = cat_rev.values > scalar
2412+
tm.assert_series_equal(res_rev, exp_rev)
2413+
tm.assert_numpy_array_equal(res_rev.values, exp_rev2)
2414+
23932415
# Only categories with same categories can be compared
23942416
def f():
23952417
cat > cat_rev
@@ -2408,9 +2430,16 @@ def f():
24082430
self.assertRaises(TypeError, lambda: a < cat)
24092431
self.assertRaises(TypeError, lambda: a < cat_rev)
24102432

2411-
# Categoricals can be compared to scalar values
2412-
res = cat_rev > base[0]
2413-
tm.assert_series_equal(res, exp)
2433+
# unequal comparison should raise for unordered cats
2434+
cat = Series(Categorical(list("abc")))
2435+
def f():
2436+
cat > "b"
2437+
self.assertRaises(TypeError, f)
2438+
cat = Series(Categorical(list("abc"), ordered=False))
2439+
def f():
2440+
cat > "b"
2441+
self.assertRaises(TypeError, f)
2442+
24142443

24152444
# And test NaN handling...
24162445
cat = Series(Categorical(["a","b","c", np.nan]))

0 commit comments

Comments
 (0)