From 7d3d44580705ad92af7e31975eab10a6fb47e125 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 30 Apr 2020 18:59:39 -0700 Subject: [PATCH 1/5] REF: implement _validate_listlike --- pandas/core/arrays/datetimelike.py | 109 ++++++++++++++--------------- 1 file changed, 52 insertions(+), 57 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 8b6ed002b3f47..c348ca84083b8 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -84,24 +84,10 @@ def _validate_comparison_value(self, other): raise ValueError("Lengths must match") else: - if isinstance(other, list): - # TODO: could use pd.Index to do inference? - other = np.array(other) - - if not isinstance(other, (np.ndarray, type(self))): - raise InvalidComparison(other) - - elif is_object_dtype(other.dtype): - pass - - elif not type(self)._is_recognized_dtype(other.dtype): - raise InvalidComparison(other) - - else: - # For PeriodDType this casting is unnecessary - # TODO: use Index to do inference? - other = type(self)._from_sequence(other) - self._check_compatible_with(other) + try: + other = self._validate_listlike(other, allow_object=True) + except TypeError as err: + raise InvalidComparison(other) from err return other @@ -126,6 +112,7 @@ def wrapper(self, other): other_i8 = self._unbox_scalar(other) else: # Then type(other) == type(self) + self._check_compatible_with(other) other_i8 = other.asi8 result = op(self.asi8, other_i8) @@ -767,6 +754,47 @@ def _validate_shift_value(self, fill_value): fill_value = self._unbox_scalar(fill_value) return fill_value + def _validate_listlike( + self, + value, + cast_str: bool = False, + cast_cat: bool = False, + allow_object: bool = False, + ): + if isinstance(value, type(self)): + return value + + # Do type inference if necessary up front + # e.g. we passed PeriodIndex.values and got an ndarray of Periods + value = array(value) + value = extract_array(value, extract_numpy=True) + + if is_dtype_equal(value.dtype, "string") and cast_str: + # We got a StringArray + try: + # TODO: Could use from_sequence_of_strings if implemented + # Note: passing dtype is necessary for PeriodArray tests + value = type(self)._from_sequence(value, dtype=self.dtype) + except ValueError: + pass + + if is_categorical_dtype(value.dtype) and cast_cat: + # e.g. we have a Categorical holding self.dtype + if is_dtype_equal(value.categories.dtype, self.dtype): + # TODO: do we need equal dtype or just comparable? + value = value._internal_get_values() + + if allow_object and is_object_dtype(value.dtype): + pass + + elif not type(self)._is_recognized_dtype(value): + raise TypeError( # FIXME: dont be searchsorted-specific + "searchsorted requires compatible dtype or scalar, " + f"not {type(value).__name__}" + ) + + return value + def _validate_searchsorted_value(self, value): if isinstance(value, str): try: @@ -782,20 +810,12 @@ def _validate_searchsorted_value(self, value): elif isinstance(value, self._recognized_scalars): value = self._scalar_type(value) - elif isinstance(value, type(self)): - pass - - elif is_list_like(value) and not isinstance(value, type(self)): - value = array(value) - - if not type(self)._is_recognized_dtype(value): - raise TypeError( - "searchsorted requires compatible dtype or scalar, " - f"not {type(value).__name__}" - ) + elif not is_list_like(value): + raise TypeError(f"Unexpected type for 'value': {type(value)}") else: - raise TypeError(f"Unexpected type for 'value': {type(value)}") + # TODO: cast_str? we accept it for scalar + value = self._validate_listlike(value) if isinstance(value, type(self)): self._check_compatible_with(value) @@ -808,21 +828,7 @@ def _validate_searchsorted_value(self, value): def _validate_setitem_value(self, value): if is_list_like(value): - value = array(value) - if is_dtype_equal(value.dtype, "string"): - # We got a StringArray - try: - # TODO: Could use from_sequence_of_strings if implemented - # Note: passing dtype is necessary for PeriodArray tests - value = type(self)._from_sequence(value, dtype=self.dtype) - except ValueError: - pass - - if not type(self)._is_recognized_dtype(value): - raise TypeError( - "setitem requires compatible dtype or scalar, " - f"not {type(value).__name__}" - ) + value = self._validate_listlike(value, cast_str=True) elif isinstance(value, self._recognized_scalars): value = self._scalar_type(value) @@ -869,18 +875,7 @@ def _validate_where_value(self, other): raise TypeError(f"Where requires matching dtype, not {type(other)}") else: - # Do type inference if necessary up front - # e.g. we passed PeriodIndex.values and got an ndarray of Periods - other = array(other) - other = extract_array(other, extract_numpy=True) - - if is_categorical_dtype(other.dtype): - # e.g. we have a Categorical holding self.dtype - if is_dtype_equal(other.categories.dtype, self.dtype): - other = other._internal_get_values() - - if not type(self)._is_recognized_dtype(other.dtype): - raise TypeError(f"Where requires matching dtype, not {other.dtype}") + other = self._validate_listlike(other, cast_cat=True) self._check_compatible_with(other, setitem=True) if lib.is_scalar(other): From 278fca687808bf54b1800118a182fa14d21a121d Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 30 Apr 2020 19:22:41 -0700 Subject: [PATCH 2/5] generic exception message --- pandas/core/arrays/datetimelike.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index c348ca84083b8..ad9258327f04a 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -788,8 +788,8 @@ def _validate_listlike( pass elif not type(self)._is_recognized_dtype(value): - raise TypeError( # FIXME: dont be searchsorted-specific - "searchsorted requires compatible dtype or scalar, " + raise TypeError( + "Operation requires compatible dtype or scalar, " f"not {type(value).__name__}" ) From 87848a46641379ff9a9254200232cc4eb5429b8a Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 30 Apr 2020 20:02:19 -0700 Subject: [PATCH 3/5] method-specific exception messages --- pandas/core/arrays/datetimelike.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index ad9258327f04a..a8c270981398e 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -85,7 +85,7 @@ def _validate_comparison_value(self, other): else: try: - other = self._validate_listlike(other, allow_object=True) + other = self._validate_listlike(other, opname, allow_object=True) except TypeError as err: raise InvalidComparison(other) from err @@ -757,6 +757,7 @@ def _validate_shift_value(self, fill_value): def _validate_listlike( self, value, + opname: str, cast_str: bool = False, cast_cat: bool = False, allow_object: bool = False, @@ -769,7 +770,7 @@ def _validate_listlike( value = array(value) value = extract_array(value, extract_numpy=True) - if is_dtype_equal(value.dtype, "string") and cast_str: + if cast_str and is_dtype_equal(value.dtype, "string"): # We got a StringArray try: # TODO: Could use from_sequence_of_strings if implemented @@ -778,7 +779,7 @@ def _validate_listlike( except ValueError: pass - if is_categorical_dtype(value.dtype) and cast_cat: + if cast_cat and is_categorical_dtype(value.dtype): # e.g. we have a Categorical holding self.dtype if is_dtype_equal(value.categories.dtype, self.dtype): # TODO: do we need equal dtype or just comparable? @@ -789,7 +790,7 @@ def _validate_listlike( elif not type(self)._is_recognized_dtype(value): raise TypeError( - "Operation requires compatible dtype or scalar, " + f"{opname} requires compatible dtype or scalar, " f"not {type(value).__name__}" ) @@ -815,7 +816,7 @@ def _validate_searchsorted_value(self, value): else: # TODO: cast_str? we accept it for scalar - value = self._validate_listlike(value) + value = self._validate_listlike(value, "searchsorted") if isinstance(value, type(self)): self._check_compatible_with(value) @@ -828,7 +829,7 @@ def _validate_searchsorted_value(self, value): def _validate_setitem_value(self, value): if is_list_like(value): - value = self._validate_listlike(value, cast_str=True) + value = self._validate_listlike(value, "setitem", cast_str=True) elif isinstance(value, self._recognized_scalars): value = self._scalar_type(value) @@ -875,7 +876,7 @@ def _validate_where_value(self, other): raise TypeError(f"Where requires matching dtype, not {type(other)}") else: - other = self._validate_listlike(other, cast_cat=True) + other = self._validate_listlike(other, "where", cast_cat=True) self._check_compatible_with(other, setitem=True) if lib.is_scalar(other): From b2733fd3d4f286e86e050a780aca3ca34eda94a9 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Fri, 1 May 2020 07:22:24 -0700 Subject: [PATCH 4/5] mypy fixup --- pandas/core/arrays/datetimelike.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index a8c270981398e..d88e6c84877d1 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta import operator -from typing import Any, Sequence, Type, TypeVar, Union, cast +from typing import Any, Callable, Sequence, Type, TypeVar, Union, cast import warnings import numpy as np @@ -442,6 +442,8 @@ class DatetimeLikeArrayMixin( _generate_range """ + _is_recognized_dtype: Callable[["DatetimeLikeArrayMixin"], bool] + # ------------------------------------------------------------------ # NDArrayBackedExtensionArray compat From c8b7afb92b01ff31d97a50ec8e97e87d899ba393 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Fri, 1 May 2020 09:38:17 -0700 Subject: [PATCH 5/5] mypy fixup --- pandas/core/arrays/datetimelike.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index d88e6c84877d1..c907c07b77315 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -10,7 +10,7 @@ from pandas._libs.tslibs.period import DIFFERENT_FREQ, IncompatibleFrequency, Period from pandas._libs.tslibs.timedeltas import Timedelta, delta_to_nanoseconds from pandas._libs.tslibs.timestamps import RoundTo, round_nsint64 -from pandas._typing import DatetimeLikeScalar +from pandas._typing import DatetimeLikeScalar, DtypeObj from pandas.compat import set_function_name from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError, NullFrequencyError, PerformanceWarning @@ -442,7 +442,7 @@ class DatetimeLikeArrayMixin( _generate_range """ - _is_recognized_dtype: Callable[["DatetimeLikeArrayMixin"], bool] + _is_recognized_dtype: Callable[[DtypeObj], bool] # ------------------------------------------------------------------ # NDArrayBackedExtensionArray compat