diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 430f20b359f8b..27b2ed822a49f 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -588,9 +588,6 @@ def __setitem__( # to a period in from_sequence). For DatetimeArray, it's Timestamp... # I don't know if mypy can do that, possibly with Generics. # https://mypy.readthedocs.io/en/latest/generics.html - if lib.is_scalar(value) and not isna(value): - value = com.maybe_box_datetimelike(value) - if is_list_like(value): is_slice = isinstance(key, slice) @@ -609,21 +606,7 @@ def __setitem__( elif not len(key): return - value = type(self)._from_sequence(value, dtype=self.dtype) - self._check_compatible_with(value, setitem=True) - value = value.asi8 - elif isinstance(value, self._scalar_type): - self._check_compatible_with(value, setitem=True) - value = self._unbox_scalar(value) - elif is_valid_nat_for_dtype(value, self.dtype): - value = iNaT - else: - msg = ( - f"'value' should be a '{self._scalar_type.__name__}', 'NaT', " - f"or array of those. Got '{type(value).__name__}' instead." - ) - raise TypeError(msg) - + value = self._validate_setitem_value(value) key = check_array_indexer(self, key) self._data[key] = value self._maybe_clear_freq() @@ -682,35 +665,6 @@ def unique(self): result = unique1d(self.asi8) return type(self)(result, dtype=self.dtype) - def _validate_fill_value(self, fill_value): - """ - If a fill_value is passed to `take` convert it to an i8 representation, - raising ValueError if this is not possible. - - Parameters - ---------- - fill_value : object - - Returns - ------- - fill_value : np.int64 - - Raises - ------ - ValueError - """ - if isna(fill_value): - fill_value = iNaT - elif isinstance(fill_value, self._recognized_scalars): - self._check_compatible_with(fill_value) - fill_value = self._scalar_type(fill_value) - fill_value = self._unbox_scalar(fill_value) - else: - raise ValueError( - f"'fill_value' should be a {self._scalar_type}. Got '{fill_value}'." - ) - return fill_value - def take(self, indices, allow_fill=False, fill_value=None): if allow_fill: fill_value = self._validate_fill_value(fill_value) @@ -769,6 +723,45 @@ def shift(self, periods=1, fill_value=None, axis=0): if not self.size or periods == 0: return self.copy() + fill_value = self._validate_shift_value(fill_value) + new_values = shift(self._data, periods, axis, fill_value) + + return type(self)._simple_new(new_values, dtype=self.dtype) + + # ------------------------------------------------------------------ + # Validation Methods + # TODO: try to de-duplicate these, ensure identical behavior + + def _validate_fill_value(self, fill_value): + """ + If a fill_value is passed to `take` convert it to an i8 representation, + raising ValueError if this is not possible. + + Parameters + ---------- + fill_value : object + + Returns + ------- + fill_value : np.int64 + + Raises + ------ + ValueError + """ + if isna(fill_value): + fill_value = iNaT + elif isinstance(fill_value, self._recognized_scalars): + self._check_compatible_with(fill_value) + fill_value = self._scalar_type(fill_value) + fill_value = self._unbox_scalar(fill_value) + else: + raise ValueError( + f"'fill_value' should be a {self._scalar_type}. Got '{fill_value}'." + ) + return fill_value + + def _validate_shift_value(self, fill_value): # TODO(2.0): once this deprecation is enforced, used _validate_fill_value if is_valid_nat_for_dtype(fill_value, self.dtype): fill_value = NaT @@ -787,15 +780,104 @@ def shift(self, periods=1, fill_value=None, axis=0): "will raise in a future version, pass " f"{self._scalar_type.__name__} instead.", FutureWarning, - stacklevel=9, + stacklevel=10, ) fill_value = new_fill fill_value = self._unbox_scalar(fill_value) + return fill_value - new_values = shift(self._data, periods, axis, fill_value) + def _validate_searchsorted_value(self, value): + if isinstance(value, str): + try: + value = self._scalar_from_string(value) + except ValueError as err: + raise TypeError( + "searchsorted requires compatible dtype or scalar" + ) from err - return type(self)._simple_new(new_values, dtype=self.dtype) + elif is_valid_nat_for_dtype(value, self.dtype): + value = NaT + + elif isinstance(value, self._recognized_scalars): + value = self._scalar_type(value) + + elif is_list_like(value) and not isinstance(value, type(self)): + value = array(value) + + if not type(self)._is_recognized_dtype(value): + raise TypeError( + "searchsorted requires compatible dtype or scalar, " + f"not {type(value).__name__}" + ) + + if not (isinstance(value, (self._scalar_type, type(self))) or (value is NaT)): + raise TypeError(f"Unexpected type for 'value': {type(value)}") + + if isinstance(value, type(self)): + self._check_compatible_with(value) + value = value.asi8 + else: + value = self._unbox_scalar(value) + + return value + + def _validate_setitem_value(self, value): + if lib.is_scalar(value) and not isna(value): + value = com.maybe_box_datetimelike(value) + + if is_list_like(value): + value = type(self)._from_sequence(value, dtype=self.dtype) + self._check_compatible_with(value, setitem=True) + value = value.asi8 + elif isinstance(value, self._scalar_type): + self._check_compatible_with(value, setitem=True) + value = self._unbox_scalar(value) + elif is_valid_nat_for_dtype(value, self.dtype): + value = iNaT + else: + msg = ( + f"'value' should be a '{self._scalar_type.__name__}', 'NaT', " + f"or array of those. Got '{type(value).__name__}' instead." + ) + raise TypeError(msg) + + return value + + def _validate_insert_value(self, value): + if isinstance(value, self._recognized_scalars): + value = self._scalar_type(value) + elif is_valid_nat_for_dtype(value, self.dtype): + # GH#18295 + value = NaT + elif lib.is_scalar(value) and isna(value): + raise TypeError( + f"cannot insert {type(self).__name__} with incompatible label" + ) + + return value + + def _validate_where_value(self, other): + if lib.is_scalar(other) and isna(other): + other = NaT.value + + else: + # Do type inference if necessary up front + # e.g. we passed PeriodIndex.values and got an ndarray of Periods + from pandas import Index + + other = Index(other) + + if is_categorical_dtype(other): + # e.g. we have a Categorical holding self.dtype + if is_dtype_equal(other.categories.dtype, self.dtype): + other = other._internal_get_values() + + if not is_dtype_equal(self.dtype, other.dtype): + raise TypeError(f"Where requires matching dtype, not {other.dtype}") + + other = other.view("i8") + return other # ------------------------------------------------------------------ # Additional array methods @@ -827,37 +909,7 @@ def searchsorted(self, value, side="left", sorter=None): indices : array of ints Array of insertion points with the same shape as `value`. """ - if isinstance(value, str): - try: - value = self._scalar_from_string(value) - except ValueError as e: - raise TypeError( - "searchsorted requires compatible dtype or scalar" - ) from e - - elif is_valid_nat_for_dtype(value, self.dtype): - value = NaT - - elif isinstance(value, self._recognized_scalars): - value = self._scalar_type(value) - - elif is_list_like(value) and not isinstance(value, type(self)): - value = array(value) - - if not type(self)._is_recognized_dtype(value): - raise TypeError( - "searchsorted requires compatible dtype or scalar, " - f"not {type(value).__name__}" - ) - - if not (isinstance(value, (self._scalar_type, type(self))) or (value is NaT)): - raise TypeError(f"Unexpected type for 'value': {type(value)}") - - if isinstance(value, type(self)): - self._check_compatible_with(value) - value = value.asi8 - else: - value = self._unbox_scalar(value) + value = self._validate_searchsorted_value(value) # TODO: Use datetime64 semantics for sorting, xref GH#29844 return self.asi8.searchsorted(value, side=side, sorter=sorter) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 067ff32b85862..3a721d8c8c320 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -17,7 +17,6 @@ ensure_int64, ensure_platform_int, is_bool_dtype, - is_categorical_dtype, is_dtype_equal, is_integer, is_list_like, @@ -26,7 +25,6 @@ ) from pandas.core.dtypes.concat import concat_compat from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries -from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna from pandas.core import algorithms from pandas.core.arrays import DatetimeArray, PeriodArray, TimedeltaArray @@ -494,23 +492,7 @@ def isin(self, values, level=None): def where(self, cond, other=None): values = self.view("i8") - if is_scalar(other) and isna(other): - other = NaT.value - - else: - # Do type inference if necessary up front - # e.g. we passed PeriodIndex.values and got an ndarray of Periods - other = Index(other) - - if is_categorical_dtype(other): - # e.g. we have a Categorical holding self.dtype - if is_dtype_equal(other.categories.dtype, self.dtype): - other = other._internal_get_values() - - if not is_dtype_equal(self.dtype, other.dtype): - raise TypeError(f"Where requires matching dtype, not {other.dtype}") - - other = other.view("i8") + other = self._data._validate_where_value(other) result = np.where(cond, values, other).astype("i8") arr = type(self._data)._simple_new(result, dtype=self.dtype) @@ -923,15 +905,7 @@ def insert(self, loc, item): ------- new_index : Index """ - if isinstance(item, self._data._recognized_scalars): - item = self._data._scalar_type(item) - elif is_valid_nat_for_dtype(item, self.dtype): - # GH 18295 - item = self._na_value - elif is_scalar(item) and isna(item): - raise TypeError( - f"cannot insert {type(self).__name__} with incompatible label" - ) + item = self._data._validate_insert_value(item) freq = None if isinstance(item, self._data._scalar_type) or item is NaT: