diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 27b2ed822a49f..6489d07ebdc4b 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -749,7 +749,7 @@ def _validate_fill_value(self, fill_value): ------ ValueError """ - if isna(fill_value): + if is_valid_nat_for_dtype(fill_value, self.dtype): fill_value = iNaT elif isinstance(fill_value, self._recognized_scalars): self._check_compatible_with(fill_value) @@ -757,7 +757,8 @@ def _validate_fill_value(self, fill_value): fill_value = self._unbox_scalar(fill_value) else: raise ValueError( - f"'fill_value' should be a {self._scalar_type}. Got '{fill_value}'." + f"'fill_value' should be a {self._scalar_type}. " + f"Got '{str(fill_value)}'." ) return fill_value @@ -858,8 +859,11 @@ def _validate_insert_value(self, value): return value def _validate_where_value(self, other): - if lib.is_scalar(other) and isna(other): + if is_valid_nat_for_dtype(other, self.dtype): other = NaT.value + elif not is_list_like(other): + # TODO: what about own-type scalars? + raise TypeError(f"Where requires matching dtype, not {type(other)}") else: # Do type inference if necessary up front diff --git a/pandas/tests/arrays/test_datetimelike.py b/pandas/tests/arrays/test_datetimelike.py index 5b703cfe8fae5..80739b9512953 100644 --- a/pandas/tests/arrays/test_datetimelike.py +++ b/pandas/tests/arrays/test_datetimelike.py @@ -508,6 +508,12 @@ def test_take_fill_valid(self, datetime_index, tz_naive_fixture): # require NaT, not iNaT, as it could be confused with an integer arr.take([-1, 1], allow_fill=True, fill_value=value) + value = np.timedelta64("NaT", "ns") + msg = f"'fill_value' should be a {self.dtype}. Got '{str(value)}'." + with pytest.raises(ValueError, match=msg): + # require appropriate-dtype if we have a NA value + arr.take([-1, 1], allow_fill=True, fill_value=value) + def test_concat_same_type_invalid(self, datetime_index): # different timezones dti = datetime_index @@ -669,6 +675,12 @@ def test_take_fill_valid(self, timedelta_index): # fill_value Period invalid arr.take([0, 1], allow_fill=True, fill_value=value) + value = np.datetime64("NaT", "ns") + msg = f"'fill_value' should be a {self.dtype}. Got '{str(value)}'." + with pytest.raises(ValueError, match=msg): + # require appropriate-dtype if we have a NA value + arr.take([-1, 1], allow_fill=True, fill_value=value) + class TestPeriodArray(SharedTests): index_cls = pd.PeriodIndex @@ -697,6 +709,22 @@ def test_astype_object(self, period_index): assert asobj.dtype == "O" assert list(asobj) == list(pi) + def test_take_fill_valid(self, period_index): + pi = period_index + arr = PeriodArray(pi) + + value = pd.NaT.value + msg = f"'fill_value' should be a {self.dtype}. Got '{value}'." + with pytest.raises(ValueError, match=msg): + # require NaT, not iNaT, as it could be confused with an integer + arr.take([-1, 1], allow_fill=True, fill_value=value) + + value = np.timedelta64("NaT", "ns") + msg = f"'fill_value' should be a {self.dtype}. Got '{str(value)}'." + with pytest.raises(ValueError, match=msg): + # require appropriate-dtype if we have a NA value + arr.take([-1, 1], allow_fill=True, fill_value=value) + @pytest.mark.parametrize("how", ["S", "E"]) def test_to_timestamp(self, how, period_index): pi = period_index diff --git a/pandas/tests/indexes/datetimes/test_indexing.py b/pandas/tests/indexes/datetimes/test_indexing.py index 08b8e710237c5..f9b8bd27b7f5a 100644 --- a/pandas/tests/indexes/datetimes/test_indexing.py +++ b/pandas/tests/indexes/datetimes/test_indexing.py @@ -174,7 +174,6 @@ def test_where_other(self): def test_where_invalid_dtypes(self): dti = pd.date_range("20130101", periods=3, tz="US/Eastern") - i2 = dti.copy() i2 = Index([pd.NaT, pd.NaT] + dti[2:].tolist()) with pytest.raises(TypeError, match="Where requires matching dtype"): @@ -194,6 +193,14 @@ def test_where_invalid_dtypes(self): with pytest.raises(TypeError, match="Where requires matching dtype"): dti.where(notna(i2), i2.asi8) + with pytest.raises(TypeError, match="Where requires matching dtype"): + # non-matching scalar + dti.where(notna(i2), pd.Timedelta(days=4)) + + with pytest.raises(TypeError, match="Where requires matching dtype"): + # non-matching NA value + dti.where(notna(i2), np.timedelta64("NaT", "ns")) + def test_where_tz(self): i = pd.date_range("20130101", periods=3, tz="US/Eastern") result = i.where(notna(i)) diff --git a/pandas/tests/indexes/period/test_indexing.py b/pandas/tests/indexes/period/test_indexing.py index c4aaf6332ba15..bd71c04a9ab03 100644 --- a/pandas/tests/indexes/period/test_indexing.py +++ b/pandas/tests/indexes/period/test_indexing.py @@ -526,7 +526,6 @@ def test_where_other(self): def test_where_invalid_dtypes(self): pi = period_range("20130101", periods=5, freq="D") - i2 = pi.copy() i2 = PeriodIndex([NaT, NaT] + pi[2:].tolist(), freq="D") with pytest.raises(TypeError, match="Where requires matching dtype"): @@ -538,6 +537,14 @@ def test_where_invalid_dtypes(self): with pytest.raises(TypeError, match="Where requires matching dtype"): pi.where(notna(i2), i2.to_timestamp("S")) + with pytest.raises(TypeError, match="Where requires matching dtype"): + # non-matching scalar + pi.where(notna(i2), Timedelta(days=4)) + + with pytest.raises(TypeError, match="Where requires matching dtype"): + # non-matching NA value + pi.where(notna(i2), np.timedelta64("NaT", "ns")) + class TestTake: def test_take(self): diff --git a/pandas/tests/indexes/timedeltas/test_indexing.py b/pandas/tests/indexes/timedeltas/test_indexing.py index 8c39a9c40a69b..17feed3fd7a68 100644 --- a/pandas/tests/indexes/timedeltas/test_indexing.py +++ b/pandas/tests/indexes/timedeltas/test_indexing.py @@ -148,7 +148,6 @@ def test_where_doesnt_retain_freq(self): def test_where_invalid_dtypes(self): tdi = timedelta_range("1 day", periods=3, freq="D", name="idx") - i2 = tdi.copy() i2 = Index([pd.NaT, pd.NaT] + tdi[2:].tolist()) with pytest.raises(TypeError, match="Where requires matching dtype"): @@ -160,6 +159,14 @@ def test_where_invalid_dtypes(self): with pytest.raises(TypeError, match="Where requires matching dtype"): tdi.where(notna(i2), (i2 + pd.Timestamp.now()).to_period("D")) + with pytest.raises(TypeError, match="Where requires matching dtype"): + # non-matching scalar + tdi.where(notna(i2), pd.Timestamp.now()) + + with pytest.raises(TypeError, match="Where requires matching dtype"): + # non-matching NA value + tdi.where(notna(i2), np.datetime64("NaT", "ns")) + class TestTake: def test_take(self): diff --git a/pandas/tests/indexing/test_coercion.py b/pandas/tests/indexing/test_coercion.py index c390347236ad3..6cb73823adabb 100644 --- a/pandas/tests/indexing/test_coercion.py +++ b/pandas/tests/indexing/test_coercion.py @@ -700,7 +700,7 @@ def test_where_index_datetime(self): assert obj.dtype == "datetime64[ns]" cond = pd.Index([True, False, True, False]) - msg = "Index\\(\\.\\.\\.\\) must be called with a collection of some kind" + msg = "Where requires matching dtype, not .*Timestamp" with pytest.raises(TypeError, match=msg): obj.where(cond, fill_val)