Skip to content

BUG: arg validation in DTA/TDA/PA.take, DTI/TDI/PI.where #33685

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,15 +749,16 @@ def _validate_fill_value(self, fill_value):
------
ValueError
"""
if isna(fill_value):
if is_valid_nat_for_dtype(fill_value, self.dtype):
fill_value = iNaT
elif isinstance(fill_value, self._recognized_scalars):
self._check_compatible_with(fill_value)
fill_value = self._scalar_type(fill_value)
fill_value = self._unbox_scalar(fill_value)
else:
raise ValueError(
f"'fill_value' should be a {self._scalar_type}. Got '{fill_value}'."
f"'fill_value' should be a {self._scalar_type}. "
f"Got '{str(fill_value)}'."
)
return fill_value

Expand Down Expand Up @@ -858,8 +859,11 @@ def _validate_insert_value(self, value):
return value

def _validate_where_value(self, other):
if lib.is_scalar(other) and isna(other):
if is_valid_nat_for_dtype(other, self.dtype):
other = NaT.value
elif not is_list_like(other):
# TODO: what about own-type scalars?
raise TypeError(f"Where requires matching dtype, not {type(other)}")

else:
# Do type inference if necessary up front
Expand Down
28 changes: 28 additions & 0 deletions pandas/tests/arrays/test_datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,12 @@ def test_take_fill_valid(self, datetime_index, tz_naive_fixture):
# require NaT, not iNaT, as it could be confused with an integer
arr.take([-1, 1], allow_fill=True, fill_value=value)

value = np.timedelta64("NaT", "ns")
msg = f"'fill_value' should be a {self.dtype}. Got '{str(value)}'."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think normally shouldn't need str in f-string. could this be a numpy bug?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i forget exactly which case broke without the extra str, but plausibly numpy

with pytest.raises(ValueError, match=msg):
# require appropriate-dtype if we have a NA value
arr.take([-1, 1], allow_fill=True, fill_value=value)

def test_concat_same_type_invalid(self, datetime_index):
# different timezones
dti = datetime_index
Expand Down Expand Up @@ -669,6 +675,12 @@ def test_take_fill_valid(self, timedelta_index):
# fill_value Period invalid
arr.take([0, 1], allow_fill=True, fill_value=value)

value = np.datetime64("NaT", "ns")
msg = f"'fill_value' should be a {self.dtype}. Got '{str(value)}'."
with pytest.raises(ValueError, match=msg):
# require appropriate-dtype if we have a NA value
arr.take([-1, 1], allow_fill=True, fill_value=value)


class TestPeriodArray(SharedTests):
index_cls = pd.PeriodIndex
Expand Down Expand Up @@ -697,6 +709,22 @@ def test_astype_object(self, period_index):
assert asobj.dtype == "O"
assert list(asobj) == list(pi)

def test_take_fill_valid(self, period_index):
pi = period_index
arr = PeriodArray(pi)

value = pd.NaT.value
msg = f"'fill_value' should be a {self.dtype}. Got '{value}'."
with pytest.raises(ValueError, match=msg):
# require NaT, not iNaT, as it could be confused with an integer
arr.take([-1, 1], allow_fill=True, fill_value=value)

value = np.timedelta64("NaT", "ns")
msg = f"'fill_value' should be a {self.dtype}. Got '{str(value)}'."
with pytest.raises(ValueError, match=msg):
# require appropriate-dtype if we have a NA value
arr.take([-1, 1], allow_fill=True, fill_value=value)

@pytest.mark.parametrize("how", ["S", "E"])
def test_to_timestamp(self, how, period_index):
pi = period_index
Expand Down
9 changes: 8 additions & 1 deletion pandas/tests/indexes/datetimes/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ def test_where_other(self):
def test_where_invalid_dtypes(self):
dti = pd.date_range("20130101", periods=3, tz="US/Eastern")

i2 = dti.copy()
i2 = Index([pd.NaT, pd.NaT] + dti[2:].tolist())

with pytest.raises(TypeError, match="Where requires matching dtype"):
Expand All @@ -194,6 +193,14 @@ def test_where_invalid_dtypes(self):
with pytest.raises(TypeError, match="Where requires matching dtype"):
dti.where(notna(i2), i2.asi8)

with pytest.raises(TypeError, match="Where requires matching dtype"):
# non-matching scalar
dti.where(notna(i2), pd.Timedelta(days=4))

with pytest.raises(TypeError, match="Where requires matching dtype"):
# non-matching NA value
dti.where(notna(i2), np.timedelta64("NaT", "ns"))

def test_where_tz(self):
i = pd.date_range("20130101", periods=3, tz="US/Eastern")
result = i.where(notna(i))
Expand Down
9 changes: 8 additions & 1 deletion pandas/tests/indexes/period/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,6 @@ def test_where_other(self):
def test_where_invalid_dtypes(self):
pi = period_range("20130101", periods=5, freq="D")

i2 = pi.copy()
i2 = PeriodIndex([NaT, NaT] + pi[2:].tolist(), freq="D")

with pytest.raises(TypeError, match="Where requires matching dtype"):
Expand All @@ -538,6 +537,14 @@ def test_where_invalid_dtypes(self):
with pytest.raises(TypeError, match="Where requires matching dtype"):
pi.where(notna(i2), i2.to_timestamp("S"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
# non-matching scalar
pi.where(notna(i2), Timedelta(days=4))

with pytest.raises(TypeError, match="Where requires matching dtype"):
# non-matching NA value
pi.where(notna(i2), np.timedelta64("NaT", "ns"))


class TestTake:
def test_take(self):
Expand Down
9 changes: 8 additions & 1 deletion pandas/tests/indexes/timedeltas/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def test_where_doesnt_retain_freq(self):
def test_where_invalid_dtypes(self):
tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")

i2 = tdi.copy()
i2 = Index([pd.NaT, pd.NaT] + tdi[2:].tolist())

with pytest.raises(TypeError, match="Where requires matching dtype"):
Expand All @@ -160,6 +159,14 @@ def test_where_invalid_dtypes(self):
with pytest.raises(TypeError, match="Where requires matching dtype"):
tdi.where(notna(i2), (i2 + pd.Timestamp.now()).to_period("D"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
# non-matching scalar
tdi.where(notna(i2), pd.Timestamp.now())

with pytest.raises(TypeError, match="Where requires matching dtype"):
# non-matching NA value
tdi.where(notna(i2), np.datetime64("NaT", "ns"))


class TestTake:
def test_take(self):
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/indexing/test_coercion.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ def test_where_index_datetime(self):
assert obj.dtype == "datetime64[ns]"
cond = pd.Index([True, False, True, False])

msg = "Index\\(\\.\\.\\.\\) must be called with a collection of some kind"
msg = "Where requires matching dtype, not .*Timestamp"
with pytest.raises(TypeError, match=msg):
obj.where(cond, fill_val)

Expand Down