Skip to content

Commit c4770c7

Browse files
jbrockmendel authored and jreback committed
Dispatch categorical Series ops to Categorical (#19582)
1 parent c49cd54 commit c4770c7

File tree

6 files changed

+99
-38
lines changed

6 files changed

+99
-38
lines changed

doc/source/whatsnew/v0.23.0.txt

+2
Original file line number | Diff line number | Diff line change
@@ -849,3 +849,5 @@ Other
849849
^^^^^
850850

851851
- Improved error message when attempting to use a Python keyword as an identifier in a ``numexpr`` backed query (:issue:`18221`)
852+
- Comparisons between :class:`Series` and :class:`Index` would return a ``Series`` with an incorrect name, ignoring the ``Index``'s name attribute (:issue:`19582`)
853+
-

pandas/core/arrays/categorical.py

+3
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,9 @@ def f(self, other):
5353
# results depending whether categories are the same or not is kind of
5454
# insane, so be a bit stricter here and use the python3 idea of
5555
# comparing only things of equal type.
56+
if isinstance(other, ABCSeries):
57+
return NotImplemented
58+
5659
if not self.ordered:
5760
if op in ['__lt__', '__gt__', '__le__', '__ge__']:
5861
raise TypeError("Unordered Categoricals can only compare "

pandas/core/indexes/category.py

+17 -8
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
import operator
2+
13
import numpy as np
24
from pandas._libs import index as libindex
35

@@ -738,7 +740,9 @@ def _codes_for_groupby(self, sort):
738740
def _add_comparison_methods(cls):
739741
""" add in comparison methods """
740742

741-
def _make_compare(opname):
743+
def _make_compare(op):
744+
opname = '__{op}__'.format(op=op.__name__)
745+
742746
def _evaluate_compare(self, other):
743747

744748
# if we have a Categorical type, then must have the same
@@ -761,16 +765,21 @@ def _evaluate_compare(self, other):
761765
"have the same categories and ordered "
762766
"attributes")
763767

764-
return getattr(self.values, opname)(other)
768+
result = op(self.values, other)
769+
if isinstance(result, ABCSeries):
770+
# Dispatch to pd.Categorical returned NotImplemented
771+
# and we got a Series back; down-cast to ndarray
772+
result = result.values
773+
return result
765774

766775
return compat.set_function_name(_evaluate_compare, opname, cls)
767776

768-
cls.__eq__ = _make_compare('__eq__')
769-
cls.__ne__ = _make_compare('__ne__')
770-
cls.__lt__ = _make_compare('__lt__')
771-
cls.__gt__ = _make_compare('__gt__')
772-
cls.__le__ = _make_compare('__le__')
773-
cls.__ge__ = _make_compare('__ge__')
777+
cls.__eq__ = _make_compare(operator.eq)
778+
cls.__ne__ = _make_compare(operator.ne)
779+
cls.__lt__ = _make_compare(operator.lt)
780+
cls.__gt__ = _make_compare(operator.gt)
781+
cls.__le__ = _make_compare(operator.le)
782+
cls.__ge__ = _make_compare(operator.ge)
774783

775784
def _delegate_method(self, name, *args, **kwargs):
776785
""" method delegation to the ._values """

pandas/core/ops.py

+42 -30
Original file line number | Diff line number | Diff line change
@@ -819,7 +819,7 @@ def dispatch_to_index_op(op, left, right, index_class):
819819
# avoid accidentally allowing integer add/sub. For datetime64[tz] dtypes,
820820
# left_idx may inherit a freq from a cached DatetimeIndex.
821821
# See discussion in GH#19147.
822-
if left_idx.freq is not None:
822+
if getattr(left_idx, 'freq', None) is not None:
823823
left_idx = left_idx._shallow_copy(freq=None)
824824
try:
825825
result = op(left_idx, right)
@@ -867,9 +867,8 @@ def na_op(x, y):
867867

868868
# dispatch to the categorical if we have a categorical
869869
# in either operand
870-
if is_categorical_dtype(x):
871-
return op(x, y)
872-
elif is_categorical_dtype(y) and not is_scalar(y):
870+
if is_categorical_dtype(y) and not is_scalar(y):
871+
# The `not is_scalar(y)` check excludes the string "category"
873872
return op(y, x)
874873

875874
elif is_object_dtype(x.dtype):
@@ -917,17 +916,36 @@ def wrapper(self, other, axis=None):
917916
if axis is not None:
918917
self._get_axis_number(axis)
919918

919+
res_name = _get_series_op_result_name(self, other)
920+
920921
if isinstance(other, ABCDataFrame): # pragma: no cover
921922
# Defer to DataFrame implementation; fail early
922923
return NotImplemented
923924

925+
elif isinstance(other, ABCSeries) and not self._indexed_same(other):
926+
raise ValueError("Can only compare identically-labeled "
927+
"Series objects")
928+
929+
elif is_categorical_dtype(self):
930+
# Dispatch to Categorical implementation; pd.CategoricalIndex
931+
# behavior is non-canonical GH#19513
932+
res_values = dispatch_to_index_op(op, self, other, pd.Categorical)
933+
return self._constructor(res_values, index=self.index,
934+
name=res_name)
935+
936+
elif is_timedelta64_dtype(self):
937+
res_values = dispatch_to_index_op(op, self, other,
938+
pd.TimedeltaIndex)
939+
return self._constructor(res_values, index=self.index,
940+
name=res_name)
941+
924942
elif isinstance(other, ABCSeries):
925-
name = com._maybe_match_name(self, other)
926-
if not self._indexed_same(other):
927-
msg = 'Can only compare identically-labeled Series objects'
928-
raise ValueError(msg)
943+
# By this point we have checked that self._indexed_same(other)
929944
res_values = na_op(self.values, other.values)
930-
return self._constructor(res_values, index=self.index, name=name)
945+
# rename is needed in case res_name is None and res_values.name
946+
# is not.
947+
return self._constructor(res_values, index=self.index,
948+
name=res_name).rename(res_name)
931949

932950
elif isinstance(other, (np.ndarray, pd.Index)):
933951
# do not check length of zerodim array
@@ -937,15 +955,17 @@ def wrapper(self, other, axis=None):
937955
raise ValueError('Lengths must match to compare')
938956

939957
res_values = na_op(self.values, np.asarray(other))
940-
return self._constructor(res_values,
941-
index=self.index).__finalize__(self)
942-
943-
elif (isinstance(other, pd.Categorical) and
944-
not is_categorical_dtype(self)):
945-
raise TypeError("Cannot compare a Categorical for op {op} with "
946-
"Series of dtype {typ}.\nIf you want to compare "
947-
"values, use 'series <op> np.asarray(other)'."
948-
.format(op=op, typ=self.dtype))
958+
result = self._constructor(res_values, index=self.index)
959+
# rename is needed in case res_name is None and self.name
960+
# is not.
961+
return result.__finalize__(self).rename(res_name)
962+
963+
elif isinstance(other, pd.Categorical):
964+
# ordering of checks matters; by this point we know
965+
# that not is_categorical_dtype(self)
966+
res_values = op(self.values, other)
967+
return self._constructor(res_values, index=self.index,
968+
name=res_name)
949969

950970
elif is_scalar(other) and isna(other):
951971
# numpy does not like comparisons vs None
@@ -956,16 +976,9 @@ def wrapper(self, other, axis=None):
956976
return self._constructor(res_values, index=self.index,
957977
name=self.name, dtype='bool')
958978

959-
if is_categorical_dtype(self):
960-
# cats are a special case as get_values() would return an ndarray,
961-
# which would then not take categories ordering into account
962-
# we can go directly to op, as the na_op would just test again and
963-
# dispatch to it.
964-
with np.errstate(all='ignore'):
965-
res = op(self.values, other)
966979
else:
967980
values = self.get_values()
968-
if isinstance(other, (list, np.ndarray)):
981+
if isinstance(other, list):
969982
other = np.asarray(other)
970983

971984
with np.errstate(all='ignore'):
@@ -975,10 +988,9 @@ def wrapper(self, other, axis=None):
975988
.format(typ=type(other)))
976989

977990
# always return a full value series here
978-
res = com._values_from_object(res)
979-
980-
res = pd.Series(res, index=self.index, name=self.name, dtype='bool')
981-
return res
991+
res_values = com._values_from_object(res)
992+
return pd.Series(res_values, index=self.index,
993+
name=res_name, dtype='bool')
982994

983995
return wrapper
984996

pandas/tests/indexes/common.py

+1
Original file line number | Diff line number | Diff line change
@@ -790,6 +790,7 @@ def test_equals_op(self):
790790
series_d = Series(array_d)
791791
with tm.assert_raises_regex(ValueError, "Lengths must match"):
792792
index_a == series_b
793+
793794
tm.assert_numpy_array_equal(index_a == series_a, expected1)
794795
tm.assert_numpy_array_equal(index_a == series_c, expected2)
795796

pandas/tests/series/test_arithmetic.py

+34
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,40 @@ def test_ser_flex_cmp_return_dtypes_empty(self, opname):
4343
result = getattr(empty, opname)(const).get_dtype_counts()
4444
tm.assert_series_equal(result, Series([1], ['bool']))
4545

46+
@pytest.mark.parametrize('op', [operator.eq, operator.ne,
47+
operator.le, operator.lt,
48+
operator.ge, operator.gt])
49+
@pytest.mark.parametrize('names', [(None, None, None),
50+
('foo', 'bar', None),
51+
('baz', 'baz', 'baz')])
52+
def test_ser_cmp_result_names(self, names, op):
53+
# datetime64 dtype
54+
dti = pd.date_range('1949-06-07 03:00:00',
55+
freq='H', periods=5, name=names[0])
56+
ser = Series(dti).rename(names[1])
57+
result = op(ser, dti)
58+
assert result.name == names[2]
59+
60+
# datetime64tz dtype
61+
dti = dti.tz_localize('US/Central')
62+
ser = Series(dti).rename(names[1])
63+
result = op(ser, dti)
64+
assert result.name == names[2]
65+
66+
# timedelta64 dtype
67+
tdi = dti - dti.shift(1)
68+
ser = Series(tdi).rename(names[1])
69+
result = op(ser, tdi)
70+
assert result.name == names[2]
71+
72+
# categorical
73+
if op in [operator.eq, operator.ne]:
74+
# categorical dtype comparisons raise for inequalities
75+
cidx = tdi.astype('category')
76+
ser = Series(cidx).rename(names[1])
77+
result = op(ser, cidx)
78+
assert result.name == names[2]
79+
4680

4781
class TestTimestampSeriesComparison(object):
4882
def test_dt64ser_cmp_period_scalar(self):

0 commit comments

Comments
 (0)