diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 0fadf3a05a1d3..4a37b4f0f29ba 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -34,7 +34,7 @@ is_unsigned_integer_dtype, pandas_dtype, ) -from pandas.core.dtypes.generic import ABCIndexClass, ABCPeriodArray, ABCSeries +from pandas.core.dtypes.generic import ABCSeries from pandas.core.dtypes.inference import is_array_like from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna @@ -368,16 +368,19 @@ class TimelikeOps: def _round(self, freq, mode, ambiguous, nonexistent): # round the local times - values = _ensure_datetimelike_to_i8(self) + if is_datetime64tz_dtype(self): + # operate on naive timestamps, then convert back to aware + naive = self.tz_localize(None) + result = naive._round(freq, mode, ambiguous, nonexistent) + aware = result.tz_localize( + self.tz, ambiguous=ambiguous, nonexistent=nonexistent + ) + return aware + + values = self.view("i8") result = round_nsint64(values, mode, freq) result = self._maybe_mask_results(result, fill_value=NaT) - - dtype = self.dtype - if is_datetime64tz_dtype(self): - dtype = None - return self._ensure_localized( - self._simple_new(result, dtype=dtype), ambiguous, nonexistent - ) + return self._simple_new(result, dtype=self.dtype) @Appender((_round_doc + _round_example).format(op="round")) def round(self, freq, ambiguous="raise", nonexistent="raise"): @@ -1411,45 +1414,6 @@ def __isub__(self, other): # type: ignore self._freq = result._freq return self - # -------------------------------------------------------------- - # Comparison Methods - - def _ensure_localized( - self, arg, ambiguous="raise", nonexistent="raise", from_utc=False - ): - """ - Ensure that we are re-localized. - - This is for compat as we can then call this on all datetimelike - arrays generally (ignored for Period/Timedelta) - - Parameters - ---------- - arg : Union[DatetimeLikeArray, DatetimeIndexOpsMixin, ndarray] - ambiguous : str, bool, or bool-ndarray, default 'raise' - nonexistent : str, default 'raise' - from_utc : bool, default False - If True, localize the i8 ndarray to UTC first before converting to - the appropriate tz. If False, localize directly to the tz. - - Returns - ------- - localized array - """ - - # reconvert to local tz - tz = getattr(self, "tz", None) - if tz is not None: - if not isinstance(arg, type(self)): - arg = self._simple_new(arg) - if from_utc: - arg = arg.tz_localize("UTC").tz_convert(self.tz) - else: - arg = arg.tz_localize( - self.tz, ambiguous=ambiguous, nonexistent=nonexistent - ) - return arg - # -------------------------------------------------------------- # Reductions @@ -1687,38 +1651,3 @@ def maybe_infer_freq(freq): freq_infer = True freq = None return freq, freq_infer - - -def _ensure_datetimelike_to_i8(other, to_utc=False): - """ - Helper for coercing an input scalar or array to i8. - - Parameters - ---------- - other : 1d array - to_utc : bool, default False - If True, convert the values to UTC before extracting the i8 values - If False, extract the i8 values directly. - - Returns - ------- - i8 1d array - """ - from pandas import Index - - if lib.is_scalar(other) and isna(other): - return iNaT - elif isinstance(other, (ABCPeriodArray, ABCIndexClass, DatetimeLikeArrayMixin)): - # convert tz if needed - if getattr(other, "tz", None) is not None: - if to_utc: - other = other.tz_convert("UTC") - else: - other = other.tz_localize(None) - else: - try: - return np.array(other, copy=False).view("i8") - except TypeError: - # period array cannot be coerced to int - other = Index(other) - return other.asi8 diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 3a58794a8b19e..bb8e04e87724f 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -16,15 +16,18 @@ from pandas.core.dtypes.common import ( ensure_int64, is_bool_dtype, + is_categorical_dtype, is_dtype_equal, is_float, is_integer, is_list_like, is_period_dtype, is_scalar, + needs_i8_conversion, ) from pandas.core.dtypes.concat import concat_compat from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries +from pandas.core.dtypes.missing import isna from pandas.core import algorithms from pandas.core.accessor import PandasDelegate @@ -34,10 +37,7 @@ ExtensionOpsMixin, TimedeltaArray, ) -from pandas.core.arrays.datetimelike import ( - DatetimeLikeArrayMixin, - _ensure_datetimelike_to_i8, -) +from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin import pandas.core.indexes.base as ibase from pandas.core.indexes.base import Index, _index_shared_docs from pandas.core.indexes.numeric import Int64Index @@ -177,18 +177,6 @@ def equals(self, other) -> bool: return np.array_equal(self.asi8, other.asi8) - def _ensure_localized( - self, arg, ambiguous="raise", nonexistent="raise", from_utc=False - ): - # See DatetimeLikeArrayMixin._ensure_localized.__doc__ - if getattr(self, "tz", None): - # ensure_localized is only relevant for tz-aware DTI - result = self._data._ensure_localized( - arg, ambiguous=ambiguous, nonexistent=nonexistent, from_utc=from_utc - ) - return type(self)._simple_new(result, name=self.name) - return arg - @Appender(_index_shared_docs["contains"] % _index_doc_kwargs) def __contains__(self, key): try: @@ -491,11 +479,27 @@ def repeat(self, repeats, axis=None): @Appender(_index_shared_docs["where"] % _index_doc_kwargs) def where(self, cond, other=None): - other = _ensure_datetimelike_to_i8(other, to_utc=True) - values = _ensure_datetimelike_to_i8(self, to_utc=True) - result = np.where(cond, values, other).astype("i8") + values = self.view("i8") + + if is_scalar(other) and isna(other): + other = NaT.value - result = self._ensure_localized(result, from_utc=True) + else: + # Do type inference if necessary up front + # e.g. we passed PeriodIndex.values and got an ndarray of Periods + other = Index(other) + + if is_categorical_dtype(other): + # e.g. we have a Categorical holding self.dtype + if needs_i8_conversion(other.categories): + other = other._internal_get_values() + + if not is_dtype_equal(self.dtype, other.dtype): + raise TypeError(f"Where requires matching dtype, not {other.dtype}") + + other = other.view("i8") + + result = np.where(cond, values, other).astype("i8") return self._shallow_copy(result) def _summary(self, name=None): diff --git a/pandas/tests/indexes/datetimes/test_indexing.py b/pandas/tests/indexes/datetimes/test_indexing.py index 210b28aa0c393..97290c8c635b8 100644 --- a/pandas/tests/indexes/datetimes/test_indexing.py +++ b/pandas/tests/indexes/datetimes/test_indexing.py @@ -132,9 +132,32 @@ def test_where_other(self): i2 = i.copy() i2 = Index([pd.NaT, pd.NaT] + i[2:].tolist()) - result = i.where(notna(i2), i2.values) + result = i.where(notna(i2), i2._values) tm.assert_index_equal(result, i2) + def test_where_invalid_dtypes(self): + dti = pd.date_range("20130101", periods=3, tz="US/Eastern") + + i2 = dti.copy() + i2 = Index([pd.NaT, pd.NaT] + dti[2:].tolist()) + + with pytest.raises(TypeError, match="Where requires matching dtype"): + # passing tz-naive ndarray to tzaware DTI + dti.where(notna(i2), i2.values) + + with pytest.raises(TypeError, match="Where requires matching dtype"): + # passing tz-aware DTI to tznaive DTI + dti.tz_localize(None).where(notna(i2), i2) + + with pytest.raises(TypeError, match="Where requires matching dtype"): + dti.where(notna(i2), i2.tz_localize(None).to_period("D")) + + with pytest.raises(TypeError, match="Where requires matching dtype"): + dti.where(notna(i2), i2.asi8.view("timedelta64[ns]")) + + with pytest.raises(TypeError, match="Where requires matching dtype"): + dti.where(notna(i2), i2.asi8) + def test_where_tz(self): i = pd.date_range("20130101", periods=3, tz="US/Eastern") result = i.where(notna(i)) diff --git a/pandas/tests/indexes/period/test_indexing.py b/pandas/tests/indexes/period/test_indexing.py index e95b4ae5397b4..7dbefbdaff98e 100644 --- a/pandas/tests/indexes/period/test_indexing.py +++ b/pandas/tests/indexes/period/test_indexing.py @@ -235,6 +235,21 @@ def test_where_other(self): result = i.where(notna(i2), i2.values) tm.assert_index_equal(result, i2) + def test_where_invalid_dtypes(self): + pi = period_range("20130101", periods=5, freq="D") + + i2 = pi.copy() + i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + pi[2:].tolist(), freq="D") + + with pytest.raises(TypeError, match="Where requires matching dtype"): + pi.where(notna(i2), i2.asi8) + + with pytest.raises(TypeError, match="Where requires matching dtype"): + pi.where(notna(i2), i2.asi8.view("timedelta64[ns]")) + + with pytest.raises(TypeError, match="Where requires matching dtype"): + pi.where(notna(i2), i2.to_timestamp("S")) + class TestTake: def test_take(self): diff --git a/pandas/tests/indexes/timedeltas/test_indexing.py b/pandas/tests/indexes/timedeltas/test_indexing.py index b70a3d17a10ab..36105477ba9ee 100644 --- a/pandas/tests/indexes/timedeltas/test_indexing.py +++ b/pandas/tests/indexes/timedeltas/test_indexing.py @@ -4,7 +4,7 @@ import pytest import pandas as pd -from pandas import Index, Timedelta, TimedeltaIndex, timedelta_range +from pandas import Index, Timedelta, TimedeltaIndex, notna, timedelta_range import pandas._testing as tm @@ -58,8 +58,20 @@ def test_timestamp_invalid_key(self, key): class TestWhere: - # placeholder for symmetry with DatetimeIndex and PeriodIndex tests - pass + def test_where_invalid_dtypes(self): + tdi = timedelta_range("1 day", periods=3, freq="D", name="idx") + + i2 = tdi.copy() + i2 = Index([pd.NaT, pd.NaT] + tdi[2:].tolist()) + + with pytest.raises(TypeError, match="Where requires matching dtype"): + tdi.where(notna(i2), i2.asi8) + + with pytest.raises(TypeError, match="Where requires matching dtype"): + tdi.where(notna(i2), i2 + pd.Timestamp.now()) + + with pytest.raises(TypeError, match="Where requires matching dtype"): + tdi.where(notna(i2), (i2 + pd.Timestamp.now()).to_period("D")) class TestTake: