Skip to content

Commit 76cca0e

Browse files
jbrockmendeljreback
authored andcommitted
BUG: fix+test assigning invalid NAT-like to DTA/TDA/PA (#27331)
1 parent eeff07f commit 76cca0e

File tree

3 files changed

+77
-2
lines changed

3 files changed

+77
-2
lines changed

pandas/core/arrays/datetimelike.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
)
3737
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
3838
from pandas.core.dtypes.inference import is_array_like
39-
from pandas.core.dtypes.missing import isna
39+
from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna
4040

4141
from pandas._typing import DatetimeLikeScalar
4242
from pandas.core import missing, nanops
@@ -492,7 +492,10 @@ def __setitem__(
492492
elif isinstance(value, self._scalar_type):
493493
self._check_compatible_with(value)
494494
value = self._unbox_scalar(value)
495-
elif isna(value) or value == iNaT:
495+
elif is_valid_nat_for_dtype(value, self.dtype):
496+
value = iNaT
497+
elif not isna(value) and lib.is_integer(value) and value == iNaT:
498+
# exclude misc e.g. object() and any NAs not allowed above
496499
value = iNaT
497500
else:
498501
msg = (

pandas/core/dtypes/missing.py

+24
Original file line numberDiff line numberDiff line change
@@ -559,3 +559,27 @@ def remove_na_arraylike(arr):
559559
return arr[notna(arr)]
560560
else:
561561
return arr[notna(lib.values_from_object(arr))]
562+
563+
564+
def is_valid_nat_for_dtype(obj, dtype):
565+
"""
566+
isna check that excludes incompatible dtypes
567+
568+
Parameters
569+
----------
570+
obj : object
571+
dtype : np.datetime64, np.timedelta64, DatetimeTZDtype, or PeriodDtype
572+
573+
Returns
574+
-------
575+
bool
576+
"""
577+
if not isna(obj):
578+
return False
579+
if dtype.kind == "M":
580+
return not isinstance(obj, np.timedelta64)
581+
if dtype.kind == "m":
582+
return not isinstance(obj, np.datetime64)
583+
584+
# must be PeriodDType
585+
return not isinstance(obj, (np.datetime64, np.timedelta64))

pandas/tests/arrays/test_datetimelike.py

+48
Original file line numberDiff line numberDiff line change
@@ -651,3 +651,51 @@ def test_array_interface(self, period_index):
651651
result = np.asarray(arr, dtype="S20")
652652
expected = np.asarray(arr).astype("S20")
653653
tm.assert_numpy_array_equal(result, expected)
654+
655+
656+
@pytest.mark.parametrize(
657+
"array,casting_nats",
658+
[
659+
(
660+
pd.TimedeltaIndex(["1 Day", "3 Hours", "NaT"])._data,
661+
(pd.NaT, np.timedelta64("NaT", "ns")),
662+
),
663+
(
664+
pd.date_range("2000-01-01", periods=3, freq="D")._data,
665+
(pd.NaT, np.datetime64("NaT", "ns")),
666+
),
667+
(pd.period_range("2000-01-01", periods=3, freq="D")._data, (pd.NaT,)),
668+
],
669+
ids=lambda x: type(x).__name__,
670+
)
671+
def test_casting_nat_setitem_array(array, casting_nats):
672+
expected = type(array)._from_sequence([pd.NaT, array[1], array[2]])
673+
674+
for nat in casting_nats:
675+
arr = array.copy()
676+
arr[0] = nat
677+
tm.assert_equal(arr, expected)
678+
679+
680+
@pytest.mark.parametrize(
681+
"array,non_casting_nats",
682+
[
683+
(
684+
pd.TimedeltaIndex(["1 Day", "3 Hours", "NaT"])._data,
685+
(np.datetime64("NaT", "ns"),),
686+
),
687+
(
688+
pd.date_range("2000-01-01", periods=3, freq="D")._data,
689+
(np.timedelta64("NaT", "ns"),),
690+
),
691+
(
692+
pd.period_range("2000-01-01", periods=3, freq="D")._data,
693+
(np.datetime64("NaT", "ns"), np.timedelta64("NaT", "ns")),
694+
),
695+
],
696+
ids=lambda x: type(x).__name__,
697+
)
698+
def test_invalid_nat_setitem_array(array, non_casting_nats):
699+
for nat in non_casting_nats:
700+
with pytest.raises(TypeError):
701+
array[0] = nat

0 commit comments

Comments
 (0)