Skip to content

Commit 0fc73ec

Browse files
jbrockmendel authored and rhshadrach committed
BUG: DTI/TDI/PI.where accepting incorrectly-typed NaTs (pandas-dev#33715)
1 parent 1271c2c commit 0fc73ec

File tree

6 files changed

+96
-31
lines changed

6 files changed

+96
-31
lines changed

pandas/core/arrays/datetimelike.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -869,27 +869,33 @@ def _validate_insert_value(self, value):
869869

870870
def _validate_where_value(self, other):
871871
if is_valid_nat_for_dtype(other, self.dtype):
872-
other = NaT.value
872+
other = NaT
873+
elif isinstance(other, self._recognized_scalars):
874+
other = self._scalar_type(other)
875+
self._check_compatible_with(other, setitem=True)
873876
elif not is_list_like(other):
874-
# TODO: what about own-type scalars?
875877
raise TypeError(f"Where requires matching dtype, not {type(other)}")
876878

877879
else:
878880
# Do type inference if necessary up front
879881
# e.g. we passed PeriodIndex.values and got an ndarray of Periods
880-
from pandas import Index
881-
882-
other = Index(other)
882+
other = array(other)
883+
other = extract_array(other, extract_numpy=True)
883884

884-
if is_categorical_dtype(other):
885+
if is_categorical_dtype(other.dtype):
885886
# e.g. we have a Categorical holding self.dtype
886887
if is_dtype_equal(other.categories.dtype, self.dtype):
887888
other = other._internal_get_values()
888889

889-
if not is_dtype_equal(self.dtype, other.dtype):
890+
if not type(self)._is_recognized_dtype(other.dtype):
890891
raise TypeError(f"Where requires matching dtype, not {other.dtype}")
892+
self._check_compatible_with(other, setitem=True)
891893

894+
if lib.is_scalar(other):
895+
other = self._unbox_scalar(other)
896+
else:
892897
other = other.view("i8")
898+
893899
return other
894900

895901
# ------------------------------------------------------------------

pandas/core/indexes/datetimelike.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,12 @@ def isin(self, values, level=None):
530530
def where(self, cond, other=None):
531531
values = self.view("i8")
532532

533-
other = self._data._validate_where_value(other)
533+
try:
534+
other = self._data._validate_where_value(other)
535+
except (TypeError, ValueError) as err:
536+
# Includes tzawareness mismatch and IncompatibleFrequencyError
537+
oth = getattr(other, "dtype", other)
538+
raise TypeError(f"Where requires matching dtype, not {oth}") from err
534539

535540
result = np.where(cond, values, other).astype("i8")
536541
arr = type(self._data)._simple_new(result, dtype=self.dtype)

pandas/tests/indexes/datetimes/test_indexing.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,15 @@ def test_where_invalid_dtypes(self):
197197
# non-matching scalar
198198
dti.where(notna(i2), pd.Timedelta(days=4))
199199

200-
with pytest.raises(TypeError, match="Where requires matching dtype"):
201-
# non-matching NA value
202-
dti.where(notna(i2), np.timedelta64("NaT", "ns"))
200+
def test_where_mismatched_nat(self, tz_aware_fixture):
201+
tz = tz_aware_fixture
202+
dti = pd.date_range("2013-01-01", periods=3, tz=tz)
203+
cond = np.array([True, False, True])
204+
205+
msg = "Where requires matching dtype"
206+
with pytest.raises(TypeError, match=msg):
207+
# wrong-dtyped NaT
208+
dti.where(cond, np.timedelta64("NaT", "ns"))
203209

204210
def test_where_tz(self):
205211
i = pd.date_range("20130101", periods=3, tz="US/Eastern")

pandas/tests/indexes/period/test_indexing.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -541,9 +541,14 @@ def test_where_invalid_dtypes(self):
541541
# non-matching scalar
542542
pi.where(notna(i2), Timedelta(days=4))
543543

544-
with pytest.raises(TypeError, match="Where requires matching dtype"):
545-
# non-matching NA value
546-
pi.where(notna(i2), np.timedelta64("NaT", "ns"))
544+
def test_where_mismatched_nat(self):
545+
pi = period_range("20130101", periods=5, freq="D")
546+
cond = np.array([True, False, True, True, False])
547+
548+
msg = "Where requires matching dtype"
549+
with pytest.raises(TypeError, match=msg):
550+
# wrong-dtyped NaT
551+
pi.where(cond, np.timedelta64("NaT", "ns"))
547552

548553

549554
class TestTake:

pandas/tests/indexes/timedeltas/test_indexing.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,14 @@ def test_where_invalid_dtypes(self):
163163
# non-matching scalar
164164
tdi.where(notna(i2), pd.Timestamp.now())
165165

166-
with pytest.raises(TypeError, match="Where requires matching dtype"):
167-
# non-matching NA value
168-
tdi.where(notna(i2), np.datetime64("NaT", "ns"))
166+
def test_where_mismatched_nat(self):
167+
tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")
168+
cond = np.array([True, False, False])
169+
170+
msg = "Where requires matching dtype"
171+
with pytest.raises(TypeError, match=msg):
172+
# wrong-dtyped NaT
173+
tdi.where(cond, np.datetime64("NaT", "ns"))
169174

170175

171176
class TestTake:

pandas/tests/indexing/test_coercion.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import timedelta
12
import itertools
23
from typing import Dict, List
34

@@ -686,8 +687,15 @@ def test_where_series_datetime64(self, fill_val, exp_dtype):
686687
)
687688
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
688689

689-
def test_where_index_datetime(self):
690-
fill_val = pd.Timestamp("2012-01-01")
690+
@pytest.mark.parametrize(
691+
"fill_val",
692+
[
693+
pd.Timestamp("2012-01-01"),
694+
pd.Timestamp("2012-01-01").to_datetime64(),
695+
pd.Timestamp("2012-01-01").to_pydatetime(),
696+
],
697+
)
698+
def test_where_index_datetime(self, fill_val):
691699
exp_dtype = "datetime64[ns]"
692700
obj = pd.Index(
693701
[
@@ -700,9 +708,9 @@ def test_where_index_datetime(self):
700708
assert obj.dtype == "datetime64[ns]"
701709
cond = pd.Index([True, False, True, False])
702710

703-
msg = "Where requires matching dtype, not .*Timestamp"
704-
with pytest.raises(TypeError, match=msg):
705-
obj.where(cond, fill_val)
711+
result = obj.where(cond, fill_val)
712+
expected = pd.DatetimeIndex([obj[0], fill_val, obj[2], fill_val])
713+
tm.assert_index_equal(result, expected)
706714

707715
values = pd.Index(pd.date_range(fill_val, periods=4))
708716
exp = pd.Index(
@@ -717,7 +725,7 @@ def test_where_index_datetime(self):
717725
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
718726

719727
@pytest.mark.xfail(reason="GH 22839: do not ignore timezone, must be object")
720-
def test_where_index_datetimetz(self):
728+
def test_where_index_datetime64tz(self):
721729
fill_val = pd.Timestamp("2012-01-01", tz="US/Eastern")
722730
exp_dtype = np.object
723731
obj = pd.Index(
@@ -754,23 +762,53 @@ def test_where_index_complex128(self):
754762
def test_where_index_bool(self):
755763
pass
756764

757-
def test_where_series_datetime64tz(self):
758-
pass
759-
760765
def test_where_series_timedelta64(self):
761766
pass
762767

763768
def test_where_series_period(self):
764769
pass
765770

766-
def test_where_index_datetime64tz(self):
767-
pass
771+
@pytest.mark.parametrize(
772+
"value", [pd.Timedelta(days=9), timedelta(days=9), np.timedelta64(9, "D")]
773+
)
774+
def test_where_index_timedelta64(self, value):
775+
tdi = pd.timedelta_range("1 Day", periods=4)
776+
cond = np.array([True, False, False, True])
768777

769-
def test_where_index_timedelta64(self):
770-
pass
778+
expected = pd.TimedeltaIndex(["1 Day", value, value, "4 Days"])
779+
result = tdi.where(cond, value)
780+
tm.assert_index_equal(result, expected)
781+
782+
msg = "Where requires matching dtype"
783+
with pytest.raises(TypeError, match=msg):
784+
# wrong-dtyped NaT
785+
tdi.where(cond, np.datetime64("NaT", "ns"))
771786

772787
def test_where_index_period(self):
773-
pass
788+
dti = pd.date_range("2016-01-01", periods=3, freq="QS")
789+
pi = dti.to_period("Q")
790+
791+
cond = np.array([False, True, False])
792+
793+
# Passing a valid scalar
794+
value = pi[-1] + pi.freq * 10
795+
expected = pd.PeriodIndex([value, pi[1], value])
796+
result = pi.where(cond, value)
797+
tm.assert_index_equal(result, expected)
798+
799+
# Case passing ndarray[object] of Periods
800+
other = np.asarray(pi + pi.freq * 10, dtype=object)
801+
result = pi.where(cond, other)
802+
expected = pd.PeriodIndex([other[0], pi[1], other[2]])
803+
tm.assert_index_equal(result, expected)
804+
805+
# Passing a mismatched scalar
806+
msg = "Where requires matching dtype"
807+
with pytest.raises(TypeError, match=msg):
808+
pi.where(cond, pd.Timedelta(days=4))
809+
810+
with pytest.raises(TypeError, match=msg):
811+
pi.where(cond, pd.Period("2020-04-21", "D"))
774812

775813

776814
class TestFillnaSeriesCoercion(CoercionBase):

0 commit comments

Comments
 (0)