Skip to content

Commit 3097224

Browse files
committed
Dispatch Series comparison ops to DatetimeIndex and TimedeltaIndex
1 parent e8620ab commit 3097224

File tree

4 files changed

+75
-25
lines changed

4 files changed

+75
-25
lines changed

pandas/core/ops.py

+31-17
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,7 @@ def na_op(x, y):
741741
if is_categorical_dtype(x):
742742
return op(x, y)
743743
elif is_categorical_dtype(y) and not is_scalar(y):
744+
# the `not is_scalar(y)` check avoids catching string "category"
744745
return op(y, x)
745746

746747
elif is_object_dtype(x.dtype):
@@ -750,7 +751,6 @@ def na_op(x, y):
750751
raise TypeError("invalid type comparison")
751752

752753
else:
753-
754754
# we want to compare like types
755755
# we only want to convert to integer like if
756756
# we are not NotImplemented, otherwise
@@ -759,23 +759,18 @@ def na_op(x, y):
759759

760760
# we have a datetime/timedelta and may need to convert
761761
mask = None
762-
if (needs_i8_conversion(x) or
763-
(not is_scalar(y) and needs_i8_conversion(y))):
764-
765-
if is_scalar(y):
766-
mask = isna(x)
767-
y = libindex.convert_scalar(x, com._values_from_object(y))
768-
else:
769-
mask = isna(x) | isna(y)
770-
y = y.view('i8')
762+
if not is_scalar(y) and needs_i8_conversion(y):
763+
mask = isna(x) | isna(y)
764+
y = y.view('i8')
771765
x = x.view('i8')
772766

773-
try:
767+
method = getattr(x, name, None)
768+
if method is not None:
774769
with np.errstate(all='ignore'):
775770
result = getattr(x, name)(y)
776771
if result is NotImplemented:
777772
raise TypeError("invalid type comparison")
778-
except AttributeError:
773+
else:
779774
result = op(x, y)
780775

781776
if mask is not None and mask.any():
@@ -788,17 +783,36 @@ def wrapper(self, other, axis=None):
788783
if axis is not None:
789784
self._get_axis_number(axis)
790785

786+
res_name = _get_series_op_result_name(self, other)
787+
791788
if isinstance(other, ABCDataFrame): # pragma: no cover
792789
# Defer to DataFrame implementation; fail early
793790
return NotImplemented
794791

792+
elif isinstance(other, ABCSeries) and not self._indexed_same(other):
793+
raise ValueError('Can only compare identically-labeled Series '
794+
'objects')
795+
796+
elif is_datetime64_dtype(self) or is_datetime64tz_dtype(self):
797+
res_values = dispatch_to_index_op(op, self, other,
798+
pd.DatetimeIndex)
799+
return _construct_result(self, res_values,
800+
index=self.index, name=res_name,
801+
dtype=res_values.dtype)
802+
803+
elif is_timedelta64_dtype(self):
804+
res_values = dispatch_to_index_op(op, self, other,
805+
pd.TimedeltaIndex)
806+
return _construct_result(self, res_values,
807+
index=self.index, name=res_name,
808+
dtype=res_values.dtype)
809+
795810
elif isinstance(other, ABCSeries):
811+
# By this point we know that self._indexed_same(other)
796812
name = com._maybe_match_name(self, other)
797-
if not self._indexed_same(other):
798-
msg = 'Can only compare identically-labeled Series objects'
799-
raise ValueError(msg)
800813
res_values = na_op(self.values, other.values)
801-
return self._constructor(res_values, index=self.index, name=name)
814+
return self._constructor(res_values, index=self.index,
815+
name=res_name)
802816

803817
elif isinstance(other, (np.ndarray, pd.Index)):
804818
# do not check length of zerodim array
@@ -836,7 +850,7 @@ def wrapper(self, other, axis=None):
836850
res = op(self.values, other)
837851
else:
838852
values = self.get_values()
839-
if isinstance(other, (list, np.ndarray)):
853+
if isinstance(other, list):
840854
other = np.asarray(other)
841855

842856
with np.errstate(all='ignore'):

pandas/tests/indexes/datetimes/test_partial_slicing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def test_loc_datetime_length_one(self):
349349

350350
@pytest.mark.parametrize('datetimelike', [
351351
Timestamp('20130101'), datetime(2013, 1, 1),
352-
date(2013, 1, 1), np.datetime64('2013-01-01T00:00', 'ns')])
352+
np.datetime64('2013-01-01T00:00', 'ns')])
353353
@pytest.mark.parametrize('op,expected', [
354354
(op.lt, [True, False, False, False]),
355355
(op.le, [True, True, False, False]),

pandas/tests/series/test_arithmetic.py

+28
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,34 @@ def test_ser_flex_cmp_return_dtypes_empty(self, opname):
4343
result = getattr(empty, opname)(const).get_dtype_counts()
4444
tm.assert_series_equal(result, Series([1], ['bool']))
4545

46+
@pytest.mark.parametrize('op', [operator.eq, operator.ne,
47+
operator.le, operator.lt,
48+
operator.ge, operator.gt])
49+
@pytest.mark.parametrize('names', [(None, None, None),
50+
('foo', 'bar', None),
51+
('baz', 'baz', 'baz')])
52+
def test_ser_cmp_result_names(self, names, op):
53+
# so far only for timedelta, and datetime dtypes
54+
55+
# datetime64 dtype
56+
dti = pd.date_range('1949-06-07 03:00:00',
57+
freq='H', periods=5, name=names[0])
58+
ser = Series(dti).rename(names[1])
59+
result = op(ser, dti)
60+
assert result.name == names[2]
61+
62+
# datetime64tz dtype
63+
dti = dti.tz_localize('US/Central')
64+
ser = Series(dti).rename(names[1])
65+
result = op(ser, dti)
66+
assert result.name == names[2]
67+
68+
# timedelta64 dtype
69+
tdi = dti - dti.shift(1)
70+
ser = Series(tdi).rename(names[1])
71+
result = op(ser, tdi)
72+
assert result.name == names[2]
73+
4674

4775
class TestTimestampSeriesComparison(object):
4876
def test_dt64ser_cmp_period_scalar(self):

pandas/tests/test_base.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pandas as pd
1111
import pandas.compat as compat
1212
from pandas.core.dtypes.common import (
13-
is_object_dtype, is_datetimetz,
13+
is_object_dtype, is_datetimetz, is_datetime64_dtype,
1414
needs_i8_conversion)
1515
import pandas.util.testing as tm
1616
from pandas import (Series, Index, DatetimeIndex, TimedeltaIndex,
@@ -297,13 +297,21 @@ def test_none_comparison(self):
297297
# assert result.iat[0]
298298
# assert result.iat[1]
299299

300-
result = None > o
301-
assert not result.iat[0]
302-
assert not result.iat[1]
300+
if is_datetime64_dtype(o) or is_datetimetz(o):
301+
# datetime dtypes follow conventions set by
302+
# Timestamp (via DatetimeIndex)
303+
with pytest.raises(TypeError):
304+
None > o
305+
with pytest.raises(TypeError):
306+
o > None
307+
else:
308+
result = None > o
309+
assert not result.iat[0]
310+
assert not result.iat[1]
303311

304-
result = o < None
305-
assert not result.iat[0]
306-
assert not result.iat[1]
312+
result = o < None
313+
assert not result.iat[0]
314+
assert not result.iat[1]
307315

308316
def test_ndarray_compat_properties(self):
309317

0 commit comments

Comments
 (0)