Skip to content

Commit 3c09d22

Browse files
authored
BUG: arg validation in DTA/TDA/PA.take, DTI/TDI/PI.where (#33685)
1 parent c6a1638 commit 3c09d22

File tree

6 files changed

+60
-7
lines changed

6 files changed

+60
-7
lines changed

pandas/core/arrays/datetimelike.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -758,15 +758,16 @@ def _validate_fill_value(self, fill_value):
758758
------
759759
ValueError
760760
"""
761-
if isna(fill_value):
761+
if is_valid_nat_for_dtype(fill_value, self.dtype):
762762
fill_value = iNaT
763763
elif isinstance(fill_value, self._recognized_scalars):
764764
self._check_compatible_with(fill_value)
765765
fill_value = self._scalar_type(fill_value)
766766
fill_value = self._unbox_scalar(fill_value)
767767
else:
768768
raise ValueError(
769-
f"'fill_value' should be a {self._scalar_type}. Got '{fill_value}'."
769+
f"'fill_value' should be a {self._scalar_type}. "
770+
f"Got '{str(fill_value)}'."
770771
)
771772
return fill_value
772773

@@ -867,8 +868,11 @@ def _validate_insert_value(self, value):
867868
return value
868869

869870
def _validate_where_value(self, other):
870-
if lib.is_scalar(other) and isna(other):
871+
if is_valid_nat_for_dtype(other, self.dtype):
871872
other = NaT.value
873+
elif not is_list_like(other):
874+
# TODO: what about own-type scalars?
875+
raise TypeError(f"Where requires matching dtype, not {type(other)}")
872876

873877
else:
874878
# Do type inference if necessary up front

pandas/tests/arrays/test_datetimelike.py

+28
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,12 @@ def test_take_fill_valid(self, datetime_index, tz_naive_fixture):
508508
# require NaT, not iNaT, as it could be confused with an integer
509509
arr.take([-1, 1], allow_fill=True, fill_value=value)
510510

511+
value = np.timedelta64("NaT", "ns")
512+
msg = f"'fill_value' should be a {self.dtype}. Got '{str(value)}'."
513+
with pytest.raises(ValueError, match=msg):
514+
# require appropriate-dtype if we have a NA value
515+
arr.take([-1, 1], allow_fill=True, fill_value=value)
516+
511517
def test_concat_same_type_invalid(self, datetime_index):
512518
# different timezones
513519
dti = datetime_index
@@ -669,6 +675,12 @@ def test_take_fill_valid(self, timedelta_index):
669675
# fill_value Period invalid
670676
arr.take([0, 1], allow_fill=True, fill_value=value)
671677

678+
value = np.datetime64("NaT", "ns")
679+
msg = f"'fill_value' should be a {self.dtype}. Got '{str(value)}'."
680+
with pytest.raises(ValueError, match=msg):
681+
# require appropriate-dtype if we have a NA value
682+
arr.take([-1, 1], allow_fill=True, fill_value=value)
683+
672684

673685
class TestPeriodArray(SharedTests):
674686
index_cls = pd.PeriodIndex
@@ -697,6 +709,22 @@ def test_astype_object(self, period_index):
697709
assert asobj.dtype == "O"
698710
assert list(asobj) == list(pi)
699711

712+
def test_take_fill_valid(self, period_index):
713+
pi = period_index
714+
arr = PeriodArray(pi)
715+
716+
value = pd.NaT.value
717+
msg = f"'fill_value' should be a {self.dtype}. Got '{value}'."
718+
with pytest.raises(ValueError, match=msg):
719+
# require NaT, not iNaT, as it could be confused with an integer
720+
arr.take([-1, 1], allow_fill=True, fill_value=value)
721+
722+
value = np.timedelta64("NaT", "ns")
723+
msg = f"'fill_value' should be a {self.dtype}. Got '{str(value)}'."
724+
with pytest.raises(ValueError, match=msg):
725+
# require appropriate-dtype if we have a NA value
726+
arr.take([-1, 1], allow_fill=True, fill_value=value)
727+
700728
@pytest.mark.parametrize("how", ["S", "E"])
701729
def test_to_timestamp(self, how, period_index):
702730
pi = period_index

pandas/tests/indexes/datetimes/test_indexing.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,6 @@ def test_where_other(self):
174174
def test_where_invalid_dtypes(self):
175175
dti = pd.date_range("20130101", periods=3, tz="US/Eastern")
176176

177-
i2 = dti.copy()
178177
i2 = Index([pd.NaT, pd.NaT] + dti[2:].tolist())
179178

180179
with pytest.raises(TypeError, match="Where requires matching dtype"):
@@ -194,6 +193,14 @@ def test_where_invalid_dtypes(self):
194193
with pytest.raises(TypeError, match="Where requires matching dtype"):
195194
dti.where(notna(i2), i2.asi8)
196195

196+
with pytest.raises(TypeError, match="Where requires matching dtype"):
197+
# non-matching scalar
198+
dti.where(notna(i2), pd.Timedelta(days=4))
199+
200+
with pytest.raises(TypeError, match="Where requires matching dtype"):
201+
# non-matching NA value
202+
dti.where(notna(i2), np.timedelta64("NaT", "ns"))
203+
197204
def test_where_tz(self):
198205
i = pd.date_range("20130101", periods=3, tz="US/Eastern")
199206
result = i.where(notna(i))

pandas/tests/indexes/period/test_indexing.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,6 @@ def test_where_other(self):
526526
def test_where_invalid_dtypes(self):
527527
pi = period_range("20130101", periods=5, freq="D")
528528

529-
i2 = pi.copy()
530529
i2 = PeriodIndex([NaT, NaT] + pi[2:].tolist(), freq="D")
531530

532531
with pytest.raises(TypeError, match="Where requires matching dtype"):
@@ -538,6 +537,14 @@ def test_where_invalid_dtypes(self):
538537
with pytest.raises(TypeError, match="Where requires matching dtype"):
539538
pi.where(notna(i2), i2.to_timestamp("S"))
540539

540+
with pytest.raises(TypeError, match="Where requires matching dtype"):
541+
# non-matching scalar
542+
pi.where(notna(i2), Timedelta(days=4))
543+
544+
with pytest.raises(TypeError, match="Where requires matching dtype"):
545+
# non-matching NA value
546+
pi.where(notna(i2), np.timedelta64("NaT", "ns"))
547+
541548

542549
class TestTake:
543550
def test_take(self):

pandas/tests/indexes/timedeltas/test_indexing.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ def test_where_doesnt_retain_freq(self):
148148
def test_where_invalid_dtypes(self):
149149
tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")
150150

151-
i2 = tdi.copy()
152151
i2 = Index([pd.NaT, pd.NaT] + tdi[2:].tolist())
153152

154153
with pytest.raises(TypeError, match="Where requires matching dtype"):
@@ -160,6 +159,14 @@ def test_where_invalid_dtypes(self):
160159
with pytest.raises(TypeError, match="Where requires matching dtype"):
161160
tdi.where(notna(i2), (i2 + pd.Timestamp.now()).to_period("D"))
162161

162+
with pytest.raises(TypeError, match="Where requires matching dtype"):
163+
# non-matching scalar
164+
tdi.where(notna(i2), pd.Timestamp.now())
165+
166+
with pytest.raises(TypeError, match="Where requires matching dtype"):
167+
# non-matching NA value
168+
tdi.where(notna(i2), np.datetime64("NaT", "ns"))
169+
163170

164171
class TestTake:
165172
def test_take(self):

pandas/tests/indexing/test_coercion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ def test_where_index_datetime(self):
700700
assert obj.dtype == "datetime64[ns]"
701701
cond = pd.Index([True, False, True, False])
702702

703-
msg = "Index\\(\\.\\.\\.\\) must be called with a collection of some kind"
703+
msg = "Where requires matching dtype, not .*Timestamp"
704704
with pytest.raises(TypeError, match=msg):
705705
obj.where(cond, fill_val)
706706

0 commit comments

Comments
 (0)