diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 8b6ed002b3f47..fbaa4e36d14c1 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -8,7 +8,7 @@ from pandas._libs import NaT, NaTType, Timestamp, algos, iNaT, lib from pandas._libs.tslibs.c_timestamp import integer_op_not_supported from pandas._libs.tslibs.period import DIFFERENT_FREQ, IncompatibleFrequency, Period -from pandas._libs.tslibs.timedeltas import Timedelta, delta_to_nanoseconds +from pandas._libs.tslibs.timedeltas import delta_to_nanoseconds from pandas._libs.tslibs.timestamps import RoundTo, round_nsint64 from pandas._typing import DatetimeLikeScalar from pandas.compat import set_function_name @@ -52,6 +52,8 @@ from pandas.tseries import frequencies from pandas.tseries.offsets import DateOffset, Tick +DTScalarOrNaT = Union[DatetimeLikeScalar, NaTType] + def _datetimelike_array_cmp(cls, op): """ @@ -122,12 +124,7 @@ def wrapper(self, other): result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other) return result - if isinstance(other, self._scalar_type) or other is NaT: - other_i8 = self._unbox_scalar(other) - else: - # Then type(other) == type(self) - other_i8 = other.asi8 - + other_i8 = self._unbox(other) result = op(self.asi8, other_i8) o_mask = isna(other) @@ -157,9 +154,7 @@ def _scalar_type(self) -> Type[DatetimeLikeScalar]: """ raise AbstractMethodError(self) - def _scalar_from_string( - self, value: str - ) -> Union[Period, Timestamp, Timedelta, NaTType]: + def _scalar_from_string(self, value: str) -> DTScalarOrNaT: """ Construct a scalar type from a string. @@ -179,13 +174,14 @@ def _scalar_from_string( """ raise AbstractMethodError(self) - def _unbox_scalar(self, value: Union[Period, Timestamp, Timedelta, NaTType]) -> int: + def _unbox_scalar(self, value: DTScalarOrNaT) -> int: """ Unbox the integer value of a scalar `value`. Parameters ---------- - value : Union[Period, Timestamp, Timedelta] + value : Period, Timestamp, Timedelta, or NaT + Depending on subclass. Returns ------- @@ -199,7 +195,7 @@ def _unbox_scalar(self, value: Union[Period, Timestamp, Timedelta, NaTType]) -> raise AbstractMethodError(self) def _check_compatible_with( - self, other: Union[Period, Timestamp, Timedelta, NaTType], setitem: bool = False + self, other: DTScalarOrNaT, setitem: bool = False ) -> None: """ Verify that `self` and `other` are compatible. @@ -727,17 +723,16 @@ def _validate_fill_value(self, fill_value): ValueError """ if is_valid_nat_for_dtype(fill_value, self.dtype): - fill_value = iNaT + fill_value = NaT elif isinstance(fill_value, self._recognized_scalars): - self._check_compatible_with(fill_value) fill_value = self._scalar_type(fill_value) - fill_value = self._unbox_scalar(fill_value) else: raise ValueError( f"'fill_value' should be a {self._scalar_type}. " f"Got '{str(fill_value)}'." ) - return fill_value + + return self._unbox(fill_value) def _validate_shift_value(self, fill_value): # TODO(2.0): once this deprecation is enforced, use _validate_fill_value @@ -764,8 +759,7 @@ def _validate_shift_value(self, fill_value): ) fill_value = new_fill - fill_value = self._unbox_scalar(fill_value) - return fill_value + return self._unbox(fill_value) def _validate_searchsorted_value(self, value): if isinstance(value, str): @@ -797,13 +791,7 @@ def _validate_searchsorted_value(self, value): else: raise TypeError(f"Unexpected type for 'value': {type(value)}") - if isinstance(value, type(self)): - self._check_compatible_with(value) - value = value.asi8 - else: - value = self._unbox_scalar(value) - - return value + return self._unbox(value) def _validate_setitem_value(self, value): @@ -836,19 +824,11 @@ def _validate_setitem_value(self, value): raise TypeError(msg) self._check_compatible_with(value, setitem=True) - if isinstance(value, type(self)): - value = value.asi8 - else: - value = self._unbox_scalar(value) - - return value + return self._unbox(value) def _validate_insert_value(self, value): if isinstance(value, self._recognized_scalars): value = self._scalar_type(value) - self._check_compatible_with(value, setitem=True) - # TODO: if we dont have compat, should we raise or astype(object)? - # PeriodIndex does astype(object) elif is_valid_nat_for_dtype(value, self.dtype): # GH#18295 value = NaT @@ -857,6 +837,9 @@ def _validate_insert_value(self, value): f"cannot insert {type(self).__name__} with incompatible label" ) + self._check_compatible_with(value, setitem=True) + # TODO: if we dont have compat, should we raise or astype(object)? + # PeriodIndex does astype(object) return value def _validate_where_value(self, other): @@ -864,7 +847,6 @@ def _validate_where_value(self, other): other = NaT elif isinstance(other, self._recognized_scalars): other = self._scalar_type(other) - self._check_compatible_with(other, setitem=True) elif not is_list_like(other): raise TypeError(f"Where requires matching dtype, not {type(other)}") @@ -881,13 +863,20 @@ def _validate_where_value(self, other): if not type(self)._is_recognized_dtype(other.dtype): raise TypeError(f"Where requires matching dtype, not {other.dtype}") - self._check_compatible_with(other, setitem=True) + self._check_compatible_with(other, setitem=True) + return self._unbox(other) + + def _unbox(self, other) -> Union[np.int64, np.ndarray]: + """ + Unbox either a scalar with _unbox_scalar or an instance of our own type. + """ if lib.is_scalar(other): other = self._unbox_scalar(other) else: + # same type as self + self._check_compatible_with(other) other = other.view("i8") - return other # ------------------------------------------------------------------ diff --git a/pandas/core/arrays/period.py b/pandas/core/arrays/period.py index b7dfcd4cb188c..1460a2e762771 100644 --- a/pandas/core/arrays/period.py +++ b/pandas/core/arrays/period.py @@ -249,8 +249,7 @@ def _unbox_scalar(self, value: Union[Period, NaTType]) -> int: if value is NaT: return value.value elif isinstance(value, self._scalar_type): - if not isna(value): - self._check_compatible_with(value) + self._check_compatible_with(value) return value.ordinal else: raise ValueError(f"'value' should be a Period. Got '{value}' instead.")