Skip to content

Commit 6f5bab1

Browse files
jbrockmendelvictor
authored and
victor
committed
[REF] Move comparison methods to EAMixins, share code (pandas-dev#21872)
1 parent 76de7b4 commit 6f5bab1

File tree

14 files changed

+356
-206
lines changed

14 files changed

+356
-206
lines changed

pandas/_libs/src/numpy_helper.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ PANDAS_INLINE PyObject* get_value_1d(PyArrayObject* ap, Py_ssize_t i) {
3232

3333

3434
void set_array_not_contiguous(PyArrayObject* ao) {
35-
ao->flags &= ~(NPY_C_CONTIGUOUS | NPY_F_CONTIGUOUS);
35+
ao->flags &= ~(NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS);
3636
}
3737

3838
#endif // PANDAS__LIBS_SRC_NUMPY_HELPER_H_

pandas/_libs/tslibs/period.pyx

+28-9
Original file line numberDiff line numberDiff line change
@@ -1859,21 +1859,40 @@ cdef int64_t _ordinal_from_fields(year, month, quarter, day,
18591859
hour, minute, second, freq):
18601860
base, mult = get_freq_code(freq)
18611861
if quarter is not None:
1862-
year, month = _quarter_to_myear(year, quarter, freq)
1862+
year, month = quarter_to_myear(year, quarter, freq)
18631863

18641864
return period_ordinal(year, month, day, hour,
18651865
minute, second, 0, 0, base)
18661866

18671867

1868-
def _quarter_to_myear(year, quarter, freq):
1869-
if quarter is not None:
1870-
if quarter <= 0 or quarter > 4:
1871-
raise ValueError('Quarter must be 1 <= q <= 4')
1868+
def quarter_to_myear(int year, int quarter, freq):
1869+
"""
1870+
A quarterly frequency defines a "year" which may not coincide with
1871+
the calendar-year. Find the calendar-year and calendar-month associated
1872+
with the given year and quarter under the `freq`-derived calendar.
1873+
1874+
Parameters
1875+
----------
1876+
year : int
1877+
quarter : int
1878+
freq : DateOffset
1879+
1880+
Returns
1881+
-------
1882+
year : int
1883+
month : int
1884+
1885+
See Also
1886+
--------
1887+
Period.qyear
1888+
"""
1889+
if quarter <= 0 or quarter > 4:
1890+
raise ValueError('Quarter must be 1 <= q <= 4')
18721891

1873-
mnum = MONTH_NUMBERS[get_rule_month(freq)] + 1
1874-
month = (mnum + (quarter - 1) * 3) % 12 + 1
1875-
if month > mnum:
1876-
year -= 1
1892+
mnum = MONTH_NUMBERS[get_rule_month(freq)] + 1
1893+
month = (mnum + (quarter - 1) * 3) % 12 + 1
1894+
if month > mnum:
1895+
year -= 1
18771896

18781897
return year, month
18791898

pandas/core/arrays/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
from .datetimes import DatetimeArrayMixin # noqa
55
from .interval import IntervalArray # noqa
66
from .period import PeriodArrayMixin # noqa
7-
from .timedelta import TimedeltaArrayMixin # noqa
7+
from .timedeltas import TimedeltaArrayMixin # noqa

pandas/core/arrays/datetimelike.py

+116
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,53 @@
1010
DIFFERENT_FREQ_INDEX, IncompatibleFrequency)
1111

1212
from pandas.errors import NullFrequencyError, PerformanceWarning
13+
from pandas import compat
1314

1415
from pandas.tseries import frequencies
1516
from pandas.tseries.offsets import Tick
1617

1718
from pandas.core.dtypes.common import (
19+
needs_i8_conversion,
20+
is_list_like,
21+
is_bool_dtype,
1822
is_period_dtype,
1923
is_timedelta64_dtype,
2024
is_object_dtype)
25+
from pandas.core.dtypes.generic import ABCSeries, ABCDataFrame, ABCIndexClass
2126

2227
import pandas.core.common as com
2328
from pandas.core.algorithms import checked_add_with_arr
2429

2530

31+
def _make_comparison_op(op, cls):
32+
# TODO: share code with indexes.base version? Main difference is that
33+
# the block for MultiIndex was removed here.
34+
def cmp_method(self, other):
35+
if isinstance(other, ABCDataFrame):
36+
return NotImplemented
37+
38+
if isinstance(other, (np.ndarray, ABCIndexClass, ABCSeries)):
39+
if other.ndim > 0 and len(self) != len(other):
40+
raise ValueError('Lengths must match to compare')
41+
42+
if needs_i8_conversion(self) and needs_i8_conversion(other):
43+
# we may need to directly compare underlying
44+
# representations
45+
return self._evaluate_compare(other, op)
46+
47+
# numpy will show a DeprecationWarning on invalid elementwise
48+
# comparisons, this will raise in the future
49+
with warnings.catch_warnings(record=True):
50+
with np.errstate(all='ignore'):
51+
result = op(self.values, np.asarray(other))
52+
53+
return result
54+
55+
name = '__{name}__'.format(name=op.__name__)
56+
# TODO: docstring?
57+
return compat.set_function_name(cmp_method, name, cls)
58+
59+
2660
class AttributesMixin(object):
2761

2862
@property
@@ -435,3 +469,85 @@ def _addsub_offset_array(self, other, op):
435469
if not is_period_dtype(self):
436470
kwargs['freq'] = 'infer'
437471
return type(self)(res_values, **kwargs)
472+
473+
# --------------------------------------------------------------
474+
# Comparison Methods
475+
476+
def _evaluate_compare(self, other, op):
477+
"""
478+
We have been called because a comparison between
479+
8 aware arrays. numpy >= 1.11 will
480+
now warn about NaT comparisons
481+
"""
482+
# Called by comparison methods when comparing datetimelike
483+
# with datetimelike
484+
485+
if not isinstance(other, type(self)):
486+
# coerce to a similar object
487+
if not is_list_like(other):
488+
# scalar
489+
other = [other]
490+
elif lib.is_scalar(lib.item_from_zerodim(other)):
491+
# ndarray scalar
492+
other = [other.item()]
493+
other = type(self)(other)
494+
495+
# compare
496+
result = op(self.asi8, other.asi8)
497+
498+
# technically we could support bool dtyped Index
499+
# for now just return the indexing array directly
500+
mask = (self._isnan) | (other._isnan)
501+
502+
filler = iNaT
503+
if is_bool_dtype(result):
504+
filler = False
505+
506+
result[mask] = filler
507+
return result
508+
509+
# TODO: get this from ExtensionOpsMixin
510+
@classmethod
511+
def _add_comparison_methods(cls):
512+
""" add in comparison methods """
513+
# DatetimeArray and TimedeltaArray comparison methods will
514+
# call these as their super(...) methods
515+
cls.__eq__ = _make_comparison_op(operator.eq, cls)
516+
cls.__ne__ = _make_comparison_op(operator.ne, cls)
517+
cls.__lt__ = _make_comparison_op(operator.lt, cls)
518+
cls.__gt__ = _make_comparison_op(operator.gt, cls)
519+
cls.__le__ = _make_comparison_op(operator.le, cls)
520+
cls.__ge__ = _make_comparison_op(operator.ge, cls)
521+
522+
523+
DatetimeLikeArrayMixin._add_comparison_methods()
524+
525+
526+
# -------------------------------------------------------------------
527+
# Shared Constructor Helpers
528+
529+
def validate_periods(periods):
530+
"""
531+
If a `periods` argument is passed to the Datetime/Timedelta Array/Index
532+
constructor, cast it to an integer.
533+
534+
Parameters
535+
----------
536+
periods : None, float, int
537+
538+
Returns
539+
-------
540+
periods : None or int
541+
542+
Raises
543+
------
544+
TypeError
545+
if periods is None, float, or int
546+
"""
547+
if periods is not None:
548+
if lib.is_float(periods):
549+
periods = int(periods)
550+
elif not lib.is_integer(periods):
551+
raise TypeError('periods must be a number, got {periods}'
552+
.format(periods=periods))
553+
return periods

pandas/core/arrays/datetimes.py

+85-3
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,37 @@
1313

1414
from pandas.util._decorators import cache_readonly
1515
from pandas.errors import PerformanceWarning
16+
from pandas import compat
1617

1718
from pandas.core.dtypes.common import (
1819
_NS_DTYPE,
20+
is_datetimelike,
1921
is_datetime64tz_dtype,
2022
is_datetime64_dtype,
2123
is_timedelta64_dtype,
2224
_ensure_int64)
2325
from pandas.core.dtypes.dtypes import DatetimeTZDtype
26+
from pandas.core.dtypes.missing import isna
27+
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
2428

29+
import pandas.core.common as com
2530
from pandas.core.algorithms import checked_add_with_arr
2631

2732
from pandas.tseries.frequencies import to_offset, DateOffset
2833
from pandas.tseries.offsets import Tick
2934

30-
from .datetimelike import DatetimeLikeArrayMixin
35+
from pandas.core.arrays import datetimelike as dtl
36+
37+
38+
def _to_m8(key, tz=None):
39+
"""
40+
Timestamp-like => dt64
41+
"""
42+
if not isinstance(key, Timestamp):
43+
# this also converts strings
44+
key = Timestamp(key, tz=tz)
45+
46+
return np.int64(conversion.pydt_to_i8(key)).view(_NS_DTYPE)
3147

3248

3349
def _field_accessor(name, field, docstring=None):
@@ -68,7 +84,58 @@ def f(self):
6884
return property(f)
6985

7086

71-
class DatetimeArrayMixin(DatetimeLikeArrayMixin):
87+
def _dt_array_cmp(opname, cls):
88+
"""
89+
Wrap comparison operations to convert datetime-like to datetime64
90+
"""
91+
nat_result = True if opname == '__ne__' else False
92+
93+
def wrapper(self, other):
94+
meth = getattr(dtl.DatetimeLikeArrayMixin, opname)
95+
96+
if isinstance(other, (datetime, np.datetime64, compat.string_types)):
97+
if isinstance(other, datetime):
98+
# GH#18435 strings get a pass from tzawareness compat
99+
self._assert_tzawareness_compat(other)
100+
101+
other = _to_m8(other, tz=self.tz)
102+
result = meth(self, other)
103+
if isna(other):
104+
result.fill(nat_result)
105+
else:
106+
if isinstance(other, list):
107+
other = type(self)(other)
108+
elif not isinstance(other, (np.ndarray, ABCIndexClass, ABCSeries)):
109+
# Following Timestamp convention, __eq__ is all-False
110+
# and __ne__ is all True, others raise TypeError.
111+
if opname == '__eq__':
112+
return np.zeros(shape=self.shape, dtype=bool)
113+
elif opname == '__ne__':
114+
return np.ones(shape=self.shape, dtype=bool)
115+
raise TypeError('%s type object %s' %
116+
(type(other), str(other)))
117+
118+
if is_datetimelike(other):
119+
self._assert_tzawareness_compat(other)
120+
121+
result = meth(self, np.asarray(other))
122+
result = com._values_from_object(result)
123+
124+
# Make sure to pass an array to result[...]; indexing with
125+
# Series breaks with older version of numpy
126+
o_mask = np.array(isna(other))
127+
if o_mask.any():
128+
result[o_mask] = nat_result
129+
130+
if self.hasnans:
131+
result[self._isnan] = nat_result
132+
133+
return result
134+
135+
return compat.set_function_name(wrapper, opname, cls)
136+
137+
138+
class DatetimeArrayMixin(dtl.DatetimeLikeArrayMixin):
72139
"""
73140
Assumes that subclass __new__/__init__ defines:
74141
tz
@@ -222,6 +289,18 @@ def __iter__(self):
222289
# -----------------------------------------------------------------
223290
# Comparison Methods
224291

292+
@classmethod
293+
def _add_comparison_methods(cls):
294+
"""add in comparison methods"""
295+
cls.__eq__ = _dt_array_cmp('__eq__', cls)
296+
cls.__ne__ = _dt_array_cmp('__ne__', cls)
297+
cls.__lt__ = _dt_array_cmp('__lt__', cls)
298+
cls.__gt__ = _dt_array_cmp('__gt__', cls)
299+
cls.__le__ = _dt_array_cmp('__le__', cls)
300+
cls.__ge__ = _dt_array_cmp('__ge__', cls)
301+
# TODO: Some classes pass __eq__ while others pass operator.eq;
302+
# standardize this.
303+
225304
def _has_same_tz(self, other):
226305
zzone = self._timezone
227306

@@ -335,7 +414,7 @@ def _add_delta(self, delta):
335414
The result's name is set outside of _add_delta by the calling
336415
method (__add__ or __sub__)
337416
"""
338-
from pandas.core.arrays.timedelta import TimedeltaArrayMixin
417+
from pandas.core.arrays.timedeltas import TimedeltaArrayMixin
339418

340419
if isinstance(delta, (Tick, timedelta, np.timedelta64)):
341420
new_values = self._add_delta_td(delta)
@@ -1021,3 +1100,6 @@ def to_julian_date(self):
10211100
self.microsecond / 3600.0 / 1e+6 +
10221101
self.nanosecond / 3600.0 / 1e+9
10231102
) / 24.0)
1103+
1104+
1105+
DatetimeArrayMixin._add_comparison_methods()

pandas/core/arrays/period.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pandas._libs.tslib import NaT, iNaT
99
from pandas._libs.tslibs.period import (
1010
Period, IncompatibleFrequency, DIFFERENT_FREQ_INDEX,
11-
get_period_field_arr, period_asfreq_arr, _quarter_to_myear)
11+
get_period_field_arr, period_asfreq_arr)
1212
from pandas._libs.tslibs import period as libperiod
1313
from pandas._libs.tslibs.timedeltas import delta_to_nanoseconds
1414
from pandas._libs.tslibs.fields import isleapyear_arr
@@ -26,7 +26,7 @@
2626
from pandas.tseries import frequencies
2727
from pandas.tseries.offsets import Tick, DateOffset
2828

29-
from .datetimelike import DatetimeLikeArrayMixin
29+
from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin
3030

3131

3232
def _field_accessor(name, alias, docstring=None):
@@ -466,7 +466,7 @@ def _range_from_fields(year=None, month=None, quarter=None, day=None,
466466

467467
year, quarter = _make_field_arrays(year, quarter)
468468
for y, q in compat.zip(year, quarter):
469-
y, m = _quarter_to_myear(y, q, freq)
469+
y, m = libperiod.quarter_to_myear(y, q, freq)
470470
val = libperiod.period_ordinal(y, m, 1, 1, 1, 1, 0, 0, base)
471471
ordinals.append(val)
472472
else:

0 commit comments

Comments
 (0)