Skip to content

[REF] Move comparison methods to EAMixins, share code #21872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 14, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/_libs/src/numpy_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ PANDAS_INLINE PyObject* char_to_string(const char* data) {


void set_array_not_contiguous(PyArrayObject* ao) {
ao->flags &= ~(NPY_C_CONTIGUOUS | NPY_F_CONTIGUOUS);
ao->flags &= ~(NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS);
}

#endif // PANDAS__LIBS_SRC_NUMPY_HELPER_H_
37 changes: 28 additions & 9 deletions pandas/_libs/tslibs/period.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1859,21 +1859,40 @@ cdef int64_t _ordinal_from_fields(year, month, quarter, day,
hour, minute, second, freq):
base, mult = get_freq_code(freq)
if quarter is not None:
year, month = _quarter_to_myear(year, quarter, freq)
year, month = quarter_to_myear(year, quarter, freq)

return period_ordinal(year, month, day, hour,
minute, second, 0, 0, base)


def _quarter_to_myear(year, quarter, freq):
if quarter is not None:
if quarter <= 0 or quarter > 4:
raise ValueError('Quarter must be 1 <= q <= 4')
def quarter_to_myear(int year, int quarter, freq):
"""
A quarterly frequency defines a "year" which may not coincide with
the calendar-year. Find the calendar-year and calendar-month associated
with the given year and quarter under the `freq`-derived calendar.

Parameters
----------
year : int
quarter : int
freq : DateOffset

Returns
-------
year : int
month : int

See Also
--------
Period.qyear
"""
if quarter <= 0 or quarter > 4:
raise ValueError('Quarter must be 1 <= q <= 4')

mnum = MONTH_NUMBERS[get_rule_month(freq)] + 1
month = (mnum + (quarter - 1) * 3) % 12 + 1
if month > mnum:
year -= 1
mnum = MONTH_NUMBERS[get_rule_month(freq)] + 1
month = (mnum + (quarter - 1) * 3) % 12 + 1
if month > mnum:
year -= 1

return year, month

Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from .categorical import Categorical # noqa
from .datetimes import DatetimeArrayMixin # noqa
from .period import PeriodArrayMixin # noqa
from .timedelta import TimedeltaArrayMixin # noqa
from .timedeltas import TimedeltaArrayMixin # noqa
116 changes: 116 additions & 0 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,53 @@
DIFFERENT_FREQ_INDEX, IncompatibleFrequency)

from pandas.errors import NullFrequencyError, PerformanceWarning
from pandas import compat

from pandas.tseries import frequencies
from pandas.tseries.offsets import Tick

from pandas.core.dtypes.common import (
needs_i8_conversion,
is_list_like,
is_bool_dtype,
is_period_dtype,
is_timedelta64_dtype,
is_object_dtype)
from pandas.core.dtypes.generic import ABCSeries, ABCDataFrame, ABCIndexClass

import pandas.core.common as com
from pandas.core.algorithms import checked_add_with_arr


def _make_comparison_op(op, cls):
# TODO: share code with indexes.base version? Main difference is that
# the block for MultiIndex was removed here.
def cmp_method(self, other):
    """
    Compare ``self`` with ``other`` using the enclosing ``op``.

    Returns ``NotImplemented`` for DataFrame operands.  Raises
    ``ValueError`` when an ndarray, Index or Series operand with
    ``ndim > 0`` differs in length from ``self``.  When
    ``needs_i8_conversion`` holds for both operands the comparison is
    delegated to ``self._evaluate_compare``; otherwise ``op`` is applied
    to ``self.values`` and ``np.asarray(other)``.
    """
    if isinstance(other, ABCDataFrame):
        return NotImplemented

    arraylike = isinstance(other, (np.ndarray, ABCIndexClass, ABCSeries))
    if arraylike and other.ndim > 0 and len(self) != len(other):
        raise ValueError('Lengths must match to compare')

    if needs_i8_conversion(self) and needs_i8_conversion(other):
        # we may need to directly compare underlying
        # representations
        return self._evaluate_compare(other, op)

    # numpy will show a DeprecationWarning on invalid elementwise
    # comparisons, this will raise in the future
    with warnings.catch_warnings(record=True), np.errstate(all='ignore'):
        return op(self.values, np.asarray(other))

name = '__{name}__'.format(name=op.__name__)
# TODO: docstring?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if you can.

return compat.set_function_name(cmp_method, name, cls)


class AttributesMixin(object):

@property
Expand Down Expand Up @@ -435,3 +469,85 @@ def _addsub_offset_array(self, other, op):
if not is_period_dtype(self):
kwargs['freq'] = 'infer'
return type(self)(res_values, **kwargs)

# --------------------------------------------------------------
# Comparison Methods

def _evaluate_compare(self, other, op):
"""
We have been called because of a comparison between
i8-aware arrays. numpy >= 1.11 will
Copy link
Member

@gfyoung gfyoung Jul 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"8 aware arrays" ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the existing docstring in indexes.datetimelike. I'm open to suggestions.

now warn about NaT comparisons
"""
# Called by comparison methods when comparing datetimelike
# with datetimelike

if not isinstance(other, type(self)):
# coerce to a similar object
if not is_list_like(other):
# scalar
other = [other]
elif lib.is_scalar(lib.item_from_zerodim(other)):
# ndarray scalar
other = [other.item()]
other = type(self)(other)

# compare
result = op(self.asi8, other.asi8)

# technically we could support bool dtyped Index
# for now just return the indexing array directly
mask = (self._isnan) | (other._isnan)

filler = iNaT
if is_bool_dtype(result):
filler = False

result[mask] = filler
return result

# TODO: get this from ExtensionOpsMixin
@classmethod
def _add_comparison_methods(cls):
    """
    Attach the six rich-comparison dunder methods to ``cls``.

    Each method is built by ``_make_comparison_op`` from the matching
    function in the ``operator`` module.
    """
    # DatetimeArray and TimedeltaArray comparison methods will
    # call these as their super(...) methods
    comparison_ops = (operator.eq, operator.ne, operator.lt,
                      operator.gt, operator.le, operator.ge)
    for op in comparison_ops:
        dunder = '__{name}__'.format(name=op.__name__)
        setattr(cls, dunder, _make_comparison_op(op, cls))


DatetimeLikeArrayMixin._add_comparison_methods()


# -------------------------------------------------------------------
# Shared Constructor Helpers

def validate_periods(periods):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring

"""
If a `periods` argument is passed to the Datetime/Timedelta Array/Index
constructor, cast it to an integer.

Parameters
----------
periods : None, float, int

Returns
-------
periods : None or int

Raises
------
TypeError
if periods is not None, a float, or an int
"""
if periods is not None:
if lib.is_float(periods):
periods = int(periods)
elif not lib.is_integer(periods):
raise TypeError('periods must be a number, got {periods}'
.format(periods=periods))
return periods
88 changes: 85 additions & 3 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,37 @@

from pandas.util._decorators import cache_readonly
from pandas.errors import PerformanceWarning
from pandas import compat

from pandas.core.dtypes.common import (
_NS_DTYPE,
is_datetimelike,
is_datetime64tz_dtype,
is_datetime64_dtype,
is_timedelta64_dtype,
_ensure_int64)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.missing import isna
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries

import pandas.core.common as com
from pandas.core.algorithms import checked_add_with_arr

from pandas.tseries.frequencies import to_offset, DateOffset
from pandas.tseries.offsets import Tick

from .datetimelike import DatetimeLikeArrayMixin
from pandas.core.arrays import datetimelike as dtl


def _to_m8(key, tz=None):
    """
    Timestamp-like => dt64

    Anything that is not already a ``Timestamp`` (strings included) is
    first passed to ``Timestamp(key, tz=tz)``.  The i8 value from
    ``conversion.pydt_to_i8`` is then wrapped in ``np.int64`` and viewed
    as ``_NS_DTYPE``.
    """
    # this also converts strings
    stamp = key if isinstance(key, Timestamp) else Timestamp(key, tz=tz)
    i8_value = conversion.pydt_to_i8(stamp)
    return np.int64(i8_value).view(_NS_DTYPE)


def _field_accessor(name, field, docstring=None):
Expand Down Expand Up @@ -68,7 +84,58 @@ def f(self):
return property(f)


class DatetimeArrayMixin(DatetimeLikeArrayMixin):
def _dt_array_cmp(opname, cls):
    """
    Wrap comparison operations to convert datetime-like to datetime64

    The returned wrapper dispatches to the method named ``opname`` on
    ``dtl.DatetimeLikeArrayMixin`` and then overwrites every position
    that is null on either side with ``True`` for ``'__ne__'`` and
    ``False`` for any other ``opname``.

    Parameters
    ----------
    opname : str
        Dunder name of the comparison, e.g. ``'__eq__'``.
    cls : type
        Passed to ``compat.set_function_name`` together with the wrapper.

    Returns
    -------
    The result of ``compat.set_function_name(wrapper, opname, cls)``.
    """
    nat_result = opname == '__ne__'
    scalar_types = (datetime, np.datetime64, compat.string_types)
    arraylike_types = (np.ndarray, ABCIndexClass, ABCSeries)

    def _invalid_comparison(self, other):
        # Following Timestamp convention, __eq__ is all-False
        # and __ne__ is all True, others raise TypeError.
        if opname == '__eq__':
            return np.zeros(shape=self.shape, dtype=bool)
        if opname == '__ne__':
            return np.ones(shape=self.shape, dtype=bool)
        raise TypeError('%s type object %s' %
                        (type(other), str(other)))

    def wrapper(self, other):
        parent_cmp = getattr(dtl.DatetimeLikeArrayMixin, opname)

        if not isinstance(other, scalar_types):
            if isinstance(other, list):
                other = type(self)(other)
            elif not isinstance(other, arraylike_types):
                return _invalid_comparison(self, other)

            if is_datetimelike(other):
                self._assert_tzawareness_compat(other)

            result = com._values_from_object(
                parent_cmp(self, np.asarray(other)))

            # Make sure to pass an array to result[...]; indexing with
            # Series breaks with older version of numpy
            other_null = np.array(isna(other))
            if other_null.any():
                result[other_null] = nat_result
        else:
            if isinstance(other, datetime):
                # GH#18435 strings get a pass from tzawareness compat
                self._assert_tzawareness_compat(other)

            other = _to_m8(other, tz=self.tz)
            result = parent_cmp(self, other)
            if isna(other):
                result.fill(nat_result)

        if self.hasnans:
            result[self._isnan] = nat_result

        return result

    return compat.set_function_name(wrapper, opname, cls)


class DatetimeArrayMixin(dtl.DatetimeLikeArrayMixin):
"""
Assumes that subclass __new__/__init__ defines:
tz
Expand Down Expand Up @@ -222,6 +289,18 @@ def __iter__(self):
# -----------------------------------------------------------------
# Comparison Methods

@classmethod
def _add_comparison_methods(cls):
    """
    Attach the six rich-comparison dunder methods to ``cls``.

    Each method is the wrapper produced by ``_dt_array_cmp`` for the
    corresponding dunder name.
    """
    opnames = ('__eq__', '__ne__', '__lt__', '__gt__', '__le__', '__ge__')
    for opname in opnames:
        setattr(cls, opname, _dt_array_cmp(opname, cls))
    # TODO: Some classes pass __eq__ while others pass operator.eq;
    # standardize this.

def _has_same_tz(self, other):
zzone = self._timezone

Expand Down Expand Up @@ -335,7 +414,7 @@ def _add_delta(self, delta):
The result's name is set outside of _add_delta by the calling
method (__add__ or __sub__)
"""
from pandas.core.arrays.timedelta import TimedeltaArrayMixin
from pandas.core.arrays.timedeltas import TimedeltaArrayMixin

if isinstance(delta, (Tick, timedelta, np.timedelta64)):
new_values = self._add_delta_td(delta)
Expand Down Expand Up @@ -1021,3 +1100,6 @@ def to_julian_date(self):
self.microsecond / 3600.0 / 1e+6 +
self.nanosecond / 3600.0 / 1e+9
) / 24.0)


DatetimeArrayMixin._add_comparison_methods()
6 changes: 3 additions & 3 deletions pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pandas._libs.tslib import NaT, iNaT
from pandas._libs.tslibs.period import (
Period, IncompatibleFrequency, DIFFERENT_FREQ_INDEX,
get_period_field_arr, period_asfreq_arr, _quarter_to_myear)
get_period_field_arr, period_asfreq_arr)
from pandas._libs.tslibs import period as libperiod
from pandas._libs.tslibs.timedeltas import delta_to_nanoseconds
from pandas._libs.tslibs.fields import isleapyear_arr
Expand All @@ -26,7 +26,7 @@
from pandas.tseries import frequencies
from pandas.tseries.offsets import Tick, DateOffset

from .datetimelike import DatetimeLikeArrayMixin
from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin


def _field_accessor(name, alias, docstring=None):
Expand Down Expand Up @@ -466,7 +466,7 @@ def _range_from_fields(year=None, month=None, quarter=None, day=None,

year, quarter = _make_field_arrays(year, quarter)
for y, q in compat.zip(year, quarter):
y, m = _quarter_to_myear(y, q, freq)
y, m = libperiod.quarter_to_myear(y, q, freq)
val = libperiod.period_ordinal(y, m, 1, 1, 1, 1, 0, 0, base)
ordinals.append(val)
else:
Expand Down
Loading