diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 35b662eaae9a5..e2d0571405d80 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -5,7 +5,7 @@ import numpy as np -from pandas._libs import algos as libalgos, lib +from pandas._libs import algos as libalgos import pandas.compat as compat from pandas.compat import lzip, u from pandas.compat.numpy import function as nv @@ -23,7 +23,7 @@ is_timedelta64_dtype) from pandas.core.dtypes.dtypes import CategoricalDtype from pandas.core.dtypes.generic import ( - ABCCategoricalIndex, ABCDataFrame, ABCIndexClass, ABCSeries) + ABCCategoricalIndex, ABCIndexClass, ABCSeries) from pandas.core.dtypes.inference import is_hashable from pandas.core.dtypes.missing import isna, notna @@ -34,6 +34,7 @@ import pandas.core.common as com from pandas.core.config import get_option from pandas.core.missing import interpolate_2d +from pandas.core.ops import CompWrapper from pandas.core.sorting import nargsort from pandas.io.formats import console @@ -53,17 +54,13 @@ def _cat_compare_op(op): + @CompWrapper(inst_from_senior_cls=True, zerodim=True) def f(self, other): # On python2, you can usually compare any type to any type, and # Categoricals can be seen as a custom type, but having different # results depending whether categories are the same or not is kind of # insane, so be a bit stricter here and use the python3 idea of # comparing only things of equal type. - if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)): - return NotImplemented - - other = lib.item_from_zerodim(other) - if not self.ordered: if op in ['__lt__', '__gt__', '__le__', '__ge__']: raise TypeError("Unordered Categoricals can only compare " diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 73e799f9e0a36..991ae5aaeded5 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -32,6 +32,7 @@ from pandas.core.algorithms import ( checked_add_with_arr, take, unique1d, value_counts) import pandas.core.common as com +from pandas.core.ops import CompWrapper from pandas.tseries import frequencies from pandas.tseries.offsets import DateOffset, Tick @@ -982,14 +983,12 @@ def _add_timedeltalike_scalar(self, other): new_values = self._maybe_mask_results(new_values) return new_values.view('i8') + @CompWrapper(validate_len=True) def _add_delta_tdi(self, other): """ Add a delta of a TimedeltaIndex return the i8 result view """ - if len(self) != len(other): - raise ValueError("cannot add indices of unequal length") - if isinstance(other, np.ndarray): # ndarray[timedelta64]; wrap in TimedeltaIndex for op from pandas import TimedeltaIndex @@ -1034,6 +1033,7 @@ def _sub_nat(self): result.fill(iNaT) return result.view('timedelta64[ns]') + @CompWrapper(validate_len=True) def _sub_period_array(self, other): """ Subtract a Period Array/Index from self. This is only valid if self @@ -1054,9 +1054,6 @@ def _sub_period_array(self, other): .format(dtype=other.dtype, cls=type(self).__name__)) - if len(self) != len(other): - raise ValueError("cannot subtract arrays/indices of " - "unequal length") if self.freq != other.freq: msg = DIFFERENT_FREQ.format(cls=type(self).__name__, own_freq=self.freqstr, @@ -1143,7 +1140,7 @@ def _time_shift(self, periods, freq=None): Note this is different from ExtensionArray.shift, which shifts the *position* of each element, padding the end with - missing values. + missing values.x Parameters ---------- @@ -1175,8 +1172,8 @@ def _time_shift(self, periods, freq=None): return self._generate_range(start=start, end=end, periods=None, freq=self.freq) + @CompWrapper(zerodim=True) def __add__(self, other): - other = lib.item_from_zerodim(other) if isinstance(other, (ABCSeries, ABCDataFrame)): return NotImplemented @@ -1238,8 +1235,8 @@ def __radd__(self, other): # alias for __add__ return self.__add__(other) + @CompWrapper(zerodim=True) def __sub__(self, other): - other = lib.item_from_zerodim(other) if isinstance(other, (ABCSeries, ABCDataFrame)): return NotImplemented diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py index d7a8417a71be2..0d3be76b93620 100644 --- a/pandas/core/arrays/datetimes.py +++ b/pandas/core/arrays/datetimes.py @@ -20,8 +20,7 @@ is_extension_type, is_float_dtype, is_object_dtype, is_period_dtype, is_string_dtype, is_timedelta64_dtype, pandas_dtype) from pandas.core.dtypes.dtypes import DatetimeTZDtype -from pandas.core.dtypes.generic import ( - ABCDataFrame, ABCIndexClass, ABCPandasArray, ABCSeries) +from pandas.core.dtypes.generic import ABCIndexClass, ABCPandasArray, ABCSeries from pandas.core.dtypes.missing import isna from pandas.core import ops @@ -29,6 +28,7 @@ from pandas.core.arrays import datetimelike as dtl from pandas.core.arrays._ranges import generate_regular_range import pandas.core.common as com +from pandas.core.ops import CompWrapper from pandas.tseries.frequencies import get_period_alias, to_offset from pandas.tseries.offsets import Day, Tick @@ -130,12 +130,8 @@ def _dt_array_cmp(cls, op): opname = '__{name}__'.format(name=op.__name__) nat_result = True if opname == '__ne__' else False + @CompWrapper(inst_from_senior_cls=True, validate_len=True, zerodim=True) def wrapper(self, other): - if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)): - return NotImplemented - - other = lib.item_from_zerodim(other) - if isinstance(other, (datetime, np.datetime64, compat.string_types)): if isinstance(other, (datetime, np.datetime64)): # GH#18435 strings get a pass from tzawareness compat @@ -152,8 +148,6 @@ def wrapper(self, other): result.fill(nat_result) elif lib.is_scalar(other) or np.ndim(other) == 0: return ops.invalid_comparison(self, other, op) - elif len(other) != len(self): - raise ValueError("Lengths must match") else: if isinstance(other, list): try: @@ -703,11 +697,9 @@ def _assert_tzawareness_compat(self, other): # ----------------------------------------------------------------- # Arithmetic Methods + @CompWrapper(validate_len=True) def _sub_datetime_arraylike(self, other): """subtract DatetimeArray/Index or ndarray[datetime64]""" - if len(self) != len(other): - raise ValueError("cannot add indices of unequal length") - if isinstance(other, np.ndarray): assert is_datetime64_dtype(other) other = type(self)(other) diff --git a/pandas/core/arrays/period.py b/pandas/core/arrays/period.py index e0c71b5609096..331d0a190c4e8 100644 --- a/pandas/core/arrays/period.py +++ b/pandas/core/arrays/period.py @@ -16,15 +16,15 @@ from pandas.core.dtypes.common import ( _TD_DTYPE, ensure_object, is_datetime64_dtype, is_float_dtype, - is_list_like, is_period_dtype, pandas_dtype) + is_period_dtype, pandas_dtype) from pandas.core.dtypes.dtypes import PeriodDtype -from pandas.core.dtypes.generic import ( - ABCDataFrame, ABCIndexClass, ABCPeriodIndex, ABCSeries) +from pandas.core.dtypes.generic import ABCIndexClass, ABCPeriodIndex, ABCSeries from pandas.core.dtypes.missing import isna, notna import pandas.core.algorithms as algos from pandas.core.arrays import datetimelike as dtl import pandas.core.common as com +from pandas.core.ops import CompWrapper from pandas.tseries import frequencies from pandas.tseries.offsets import DateOffset, Tick, _delta_to_tick @@ -48,15 +48,10 @@ def _period_array_cmp(cls, op): opname = '__{name}__'.format(name=op.__name__) nat_result = True if opname == '__ne__' else False + @CompWrapper(validate_len=True, inst_from_senior_cls=True) def wrapper(self, other): op = getattr(self.asi8, opname) - if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)): - return NotImplemented - - if is_list_like(other) and len(other) != len(self): - raise ValueError("Lengths must match") - if isinstance(other, Period): self._check_compatible_with(other) diff --git a/pandas/core/arrays/timedeltas.py b/pandas/core/arrays/timedeltas.py index 4f0c96f7927da..86f7e9a26a9bb 100644 --- a/pandas/core/arrays/timedeltas.py +++ b/pandas/core/arrays/timedeltas.py @@ -28,6 +28,7 @@ from pandas.core import ops from pandas.core.algorithms import checked_add_with_arr import pandas.core.common as com +from pandas.core.ops import CompWrapper from pandas.tseries.frequencies import to_offset from pandas.tseries.offsets import Tick @@ -64,10 +65,8 @@ def _td_array_cmp(cls, op): opname = '__{name}__'.format(name=op.__name__) nat_result = True if opname == '__ne__' else False + @CompWrapper(validate_len=True, inst_from_senior_cls=True) def wrapper(self, other): - if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)): - return NotImplemented - if _is_convertible_to_td(other) or other is NaT: try: other = Timedelta(other) @@ -82,9 +81,6 @@ def wrapper(self, other): elif not is_list_like(other): return ops.invalid_comparison(self, other, op) - elif len(other) != len(self): - raise ValueError("Lengths must match") - else: try: other = type(self)._from_sequence(other)._data diff --git a/pandas/core/ops.py b/pandas/core/ops.py index 10cebc6f94b92..ac1b5711a5627 100644 --- a/pandas/core/ops.py +++ b/pandas/core/ops.py @@ -7,6 +7,7 @@ from __future__ import division import datetime +from functools import wraps import operator import textwrap import warnings @@ -28,8 +29,8 @@ is_integer_dtype, is_list_like, is_object_dtype, is_period_dtype, is_scalar, is_timedelta64_dtype, needs_i8_conversion) from pandas.core.dtypes.generic import ( - ABCDataFrame, ABCIndex, ABCIndexClass, ABCPanel, ABCSeries, ABCSparseArray, - ABCSparseSeries) + ABCDataFrame, ABCExtensionArray, ABCIndex, ABCIndexClass, ABCPanel, + ABCSeries, ABCSparseArray, ABCSparseSeries) from pandas.core.dtypes.missing import isna, notna import pandas as pd @@ -136,6 +137,62 @@ def maybe_upcast_for_op(obj): return obj +class CompWrapper(object): + __key__ = ['list_to_array', 'validate_len', + 'zerodim', 'inst_from_senior_cls'] + + def __init__(self, + list_to_array=None, + validate_len=None, + zerodim=None, + inst_from_senior_cls=None): + self.list_to_array = list_to_array + self.validate_len = validate_len + self.zerodim = zerodim + self.inst_from_senior_cls = inst_from_senior_cls + + def _list_to_array(self, comp): + @wraps(comp) + def wrapper(comp_self, comp_other): + if is_list_like(comp_other): + comp_other = np.asarray(comp_other) + return comp(comp_self, comp_other) + return wrapper + + def _validate_len(self, comp): + @wraps(comp) + def wrapper(comp_self, comp_other): + if is_list_like(comp_other) and len(comp_other) != len(comp_self): + raise ValueError("Lengths must match to compare") + return comp(comp_self, comp_other) + return wrapper + + def _zerodim(self, comp): + @wraps(comp) + def wrapper(comp_self, comp_other): + from pandas._libs import lib + comp_other = lib.item_from_zerodim(comp_other) + return comp(comp_self, comp_other) + return wrapper + + def _inst_from_senior_cls(self, comp): + @wraps(comp) + def wrapper(comp_self, comp_other): + if isinstance(comp_self, ABCExtensionArray): + if isinstance(comp_other, (ABCDataFrame, ABCSeries, + ABCIndexClass)): + # Rely on pandas to unbox and dispatch to us. + return NotImplemented + return comp(comp_self, comp_other) + return wrapper + + def __call__(self, comp): + for key in CompWrapper.__key__: + if getattr(self, key) is True: + comp = getattr(self, '_' + key)(comp) + return comp + + # ----------------------------------------------------------------------------- # Reversed Operations not available in the stdlib operator module. # Defining these instead of using lambdas allows us to reference them by name. diff --git a/pandas/tests/arithmetic/test_datetime64.py b/pandas/tests/arithmetic/test_datetime64.py index 405dc0805a285..99f687dbb485b 100644 --- a/pandas/tests/arithmetic/test_datetime64.py +++ b/pandas/tests/arithmetic/test_datetime64.py @@ -2091,7 +2091,7 @@ def test_sub_dti_dti(self): # different length raises ValueError dti1 = date_range('20130101', periods=3) dti2 = date_range('20130101', periods=4) - msg = 'cannot add indices of unequal length' + msg = 'Lengths must match to compare' with pytest.raises(ValueError, match=msg): dti1 - dti2