Skip to content

Dispatch Series comparison ops to DatetimeIndex and TimedeltaIndex #19524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
52 changes: 32 additions & 20 deletions pandas/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import numpy as np
import pandas as pd

from pandas._libs import (lib, index as libindex,
algos as libalgos)
from pandas._libs import lib, algos as libalgos

from pandas import compat
from pandas.util._decorators import Appender
Expand Down Expand Up @@ -741,6 +740,7 @@ def na_op(x, y):
if is_categorical_dtype(x):
return op(x, y)
elif is_categorical_dtype(y) and not is_scalar(y):
# the `not is_scalar(y)` check avoids catching string "category"
return op(y, x)

elif is_object_dtype(x.dtype):
Expand All @@ -750,7 +750,6 @@ def na_op(x, y):
raise TypeError("invalid type comparison")

else:

# we want to compare like types
# we only want to convert to integer like if
# we are not NotImplemented, otherwise
Expand All @@ -759,23 +758,18 @@ def na_op(x, y):

# we have a datetime/timedelta and may need to convert
mask = None
if (needs_i8_conversion(x) or
(not is_scalar(y) and needs_i8_conversion(y))):

if is_scalar(y):
mask = isna(x)
y = libindex.convert_scalar(x, com._values_from_object(y))
else:
mask = isna(x) | isna(y)
y = y.view('i8')
if not is_scalar(y) and needs_i8_conversion(y):
mask = isna(x) | isna(y)
y = y.view('i8')
x = x.view('i8')

try:
method = getattr(x, name, None)
if method is not None:
with np.errstate(all='ignore'):
result = getattr(x, name)(y)
if result is NotImplemented:
raise TypeError("invalid type comparison")
except AttributeError:
else:
result = op(x, y)

if mask is not None and mask.any():
Expand All @@ -788,17 +782,35 @@ def wrapper(self, other, axis=None):
if axis is not None:
self._get_axis_number(axis)

res_name = _get_series_op_result_name(self, other)

if isinstance(other, ABCDataFrame): # pragma: no cover
# Defer to DataFrame implementation; fail early
return NotImplemented

elif isinstance(other, ABCSeries) and not self._indexed_same(other):
raise ValueError('Can only compare identically-labeled Series '
'objects')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I prefer not to "orphan" words on their own line. Move the "Series" part of your sentence there too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will change. Ditto comma above.


elif is_datetime64_dtype(self) or is_datetime64tz_dtype(self):
res_values = dispatch_to_index_op(op, self, other,
pd.DatetimeIndex)
return _construct_result(self, res_values,
index=self.index, name=res_name,
dtype=res_values.dtype)

elif is_timedelta64_dtype(self):
res_values = dispatch_to_index_op(op, self, other,
pd.TimedeltaIndex)
return _construct_result(self, res_values,
index=self.index, name=res_name,
dtype=res_values.dtype)

elif isinstance(other, ABCSeries):
name = com._maybe_match_name(self, other)
if not self._indexed_same(other):
msg = 'Can only compare identically-labeled Series objects'
raise ValueError(msg)
# By this point we know that self._indexed_same(other)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comma between "point" and "we"

res_values = na_op(self.values, other.values)
return self._constructor(res_values, index=self.index, name=name)
return self._constructor(res_values, index=self.index,
name=res_name)

elif isinstance(other, (np.ndarray, pd.Index)):
# do not check length of zerodim array
Expand Down Expand Up @@ -836,7 +848,7 @@ def wrapper(self, other, axis=None):
res = op(self.values, other)
else:
values = self.get_values()
if isinstance(other, (list, np.ndarray)):
if isinstance(other, list):
other = np.asarray(other)

with np.errstate(all='ignore'):
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/indexes/datetimes/test_partial_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from datetime import datetime, date
from datetime import datetime
import numpy as np
import pandas as pd
import operator as op
Expand Down Expand Up @@ -349,7 +349,7 @@ def test_loc_datetime_length_one(self):

@pytest.mark.parametrize('datetimelike', [
Timestamp('20130101'), datetime(2013, 1, 1),
date(2013, 1, 1), np.datetime64('2013-01-01T00:00', 'ns')])
np.datetime64('2013-01-01T00:00', 'ns')])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you have to change the parametrization?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because, at the moment, Series[datetime64].__cmp__(date) treats the date as a datetime, i.e. it allows the comparison. DatetimeIndex does not, following the convention set by Timestamp (and by datetime itself). The DatetimeIndex behavior is canonical.

@pytest.mark.parametrize('op,expected', [
(op.lt, [True, False, False, False]),
(op.le, [True, True, False, False]),
Expand Down
28 changes: 28 additions & 0 deletions pandas/tests/series/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,34 @@ def test_ser_flex_cmp_return_dtypes_empty(self, opname):
result = getattr(empty, opname)(const).get_dtype_counts()
tm.assert_series_equal(result, Series([1], ['bool']))

@pytest.mark.parametrize('op', [operator.eq, operator.ne,
                                operator.le, operator.lt,
                                operator.ge, operator.gt])
@pytest.mark.parametrize('names', [(None, None, None),
                                   ('foo', 'bar', None),
                                   ('baz', 'baz', 'baz')])
def test_ser_cmp_result_names(self, names, op):
    """Check the name of the result of comparing a Series with an Index.

    ``names`` is a 3-tuple: the name given to the index, the name given
    to the Series built from that index, and the name the comparison
    result is expected to carry.  The check is run for datetime64,
    datetime64tz and timedelta64 data.
    """
    index_name, series_name, expected_name = names

    def check_result_name(index):
        # Build a Series holding the same values as ``index`` but with
        # its own name, compare the two, and inspect the result name.
        ser = Series(index).rename(series_name)
        result = op(ser, index)
        assert result.name == expected_name

    # datetime64 dtype
    naive_dti = pd.date_range('1949-06-07 03:00:00',
                              freq='H', periods=5, name=index_name)
    check_result_name(naive_dti)

    # datetime64tz dtype
    aware_dti = naive_dti.tz_localize('US/Central')
    check_result_name(aware_dti)

    # timedelta64 dtype, derived from the tz-aware index
    check_result_name(aware_dti - aware_dti.shift(1))


class TestTimestampSeriesComparison(object):
def test_dt64ser_cmp_period_scalar(self):
Expand Down
22 changes: 15 additions & 7 deletions pandas/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pandas as pd
import pandas.compat as compat
from pandas.core.dtypes.common import (
is_object_dtype, is_datetimetz,
is_object_dtype, is_datetimetz, is_datetime64_dtype,
needs_i8_conversion)
import pandas.util.testing as tm
from pandas import (Series, Index, DatetimeIndex, TimedeltaIndex,
Expand Down Expand Up @@ -297,13 +297,21 @@ def test_none_comparison(self):
# assert result.iat[0]
# assert result.iat[1]

result = None > o
assert not result.iat[0]
assert not result.iat[1]
if is_datetime64_dtype(o) or is_datetimetz(o):
# datetime dtypes follow conventions set by
# Timestamp (via DatetimeIndex)
with pytest.raises(TypeError):
None > o
with pytest.raises(TypeError):
o > None
else:
result = None > o
assert not result.iat[0]
assert not result.iat[1]

result = o < None
assert not result.iat[0]
assert not result.iat[1]
result = o < None
assert not result.iat[0]
assert not result.iat[1]

def test_ndarray_compat_properties(self):

Expand Down