Skip to content

Commit f90c4ac

Browse files
jbrockmendelNico Cernek
authored and
Nico Cernek
committed
Change maybe_promote fill_value to dt64/td64 NaT instead of iNaT (pandas-dev#28725)
1 parent e3a36ff commit f90c4ac

File tree

2 files changed

+23
-36
lines changed

2 files changed

+23
-36
lines changed

pandas/core/dtypes/cast.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def maybe_promote(dtype, fill_value=np.nan):
339339
# if we passed an array here, determine the fill value by dtype
340340
if isinstance(fill_value, np.ndarray):
341341
if issubclass(fill_value.dtype.type, (np.datetime64, np.timedelta64)):
342-
fill_value = iNaT
342+
fill_value = fill_value.dtype.type("NaT", "ns")
343343
else:
344344

345345
# we need to change to object type as our
@@ -350,9 +350,14 @@ def maybe_promote(dtype, fill_value=np.nan):
350350

351351
# returns tuple of (dtype, fill_value)
352352
if issubclass(dtype.type, np.datetime64):
353-
fill_value = tslibs.Timestamp(fill_value).value
353+
fill_value = tslibs.Timestamp(fill_value).to_datetime64()
354354
elif issubclass(dtype.type, np.timedelta64):
355-
fill_value = tslibs.Timedelta(fill_value).value
355+
fv = tslibs.Timedelta(fill_value)
356+
if fv is NaT:
357+
# NaT has no `to_timedelta6` method
358+
fill_value = np.timedelta64("NaT", "ns")
359+
else:
360+
fill_value = fv.to_timedelta64()
356361
elif is_datetime64tz_dtype(dtype):
357362
if isna(fill_value):
358363
fill_value = NaT
@@ -393,7 +398,7 @@ def maybe_promote(dtype, fill_value=np.nan):
393398
dtype = np.float64
394399
fill_value = np.nan
395400
elif is_datetime_or_timedelta_dtype(dtype):
396-
fill_value = iNaT
401+
fill_value = dtype.type("NaT", "ns")
397402
else:
398403
dtype = np.object_
399404
fill_value = np.nan

pandas/tests/dtypes/cast/test_promote.py

+14-32
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
is_string_dtype,
2323
is_timedelta64_dtype,
2424
)
25-
from pandas.core.dtypes.dtypes import DatetimeTZDtype, PandasExtensionDtype
25+
from pandas.core.dtypes.dtypes import DatetimeTZDtype
2626
from pandas.core.dtypes.missing import isna
2727

2828
import pandas as pd
@@ -92,20 +92,6 @@ def box(request):
9292
return request.param
9393

9494

95-
def _safe_dtype_assert(left_dtype, right_dtype):
96-
"""
97-
Compare two dtypes without raising TypeError.
98-
"""
99-
__tracebackhide__ = True
100-
if isinstance(right_dtype, PandasExtensionDtype):
101-
# switch order of equality check because numpy dtypes (e.g. if
102-
# left_dtype is np.object_) do not know some expected dtypes (e.g.
103-
# DatetimeTZDtype) and would raise a TypeError in their __eq__-method.
104-
assert right_dtype == left_dtype
105-
else:
106-
assert left_dtype == right_dtype
107-
108-
10995
def _check_promote(
11096
dtype,
11197
fill_value,
@@ -157,8 +143,11 @@ def _check_promote(
157143
result_dtype, result_fill_value = maybe_promote(dtype, fill_value)
158144
expected_fill_value = exp_val_for_scalar
159145

160-
_safe_dtype_assert(result_dtype, expected_dtype)
146+
assert result_dtype == expected_dtype
147+
_assert_match(result_fill_value, expected_fill_value)
148+
161149

150+
def _assert_match(result_fill_value, expected_fill_value):
162151
# GH#23982/25425 require the same type in addition to equality/NA-ness
163152
res_type = type(result_fill_value)
164153
ex_type = type(expected_fill_value)
@@ -369,8 +358,8 @@ def test_maybe_promote_any_with_datetime64(
369358
if is_datetime64_dtype(dtype):
370359
expected_dtype = dtype
371360
# for datetime dtypes, scalar values get cast to pd.Timestamp.value
372-
exp_val_for_scalar = pd.Timestamp(fill_value).value
373-
exp_val_for_array = iNaT
361+
exp_val_for_scalar = pd.Timestamp(fill_value).to_datetime64()
362+
exp_val_for_array = np.datetime64("NaT", "ns")
374363
else:
375364
expected_dtype = np.dtype(object)
376365
exp_val_for_scalar = fill_value
@@ -454,9 +443,7 @@ def test_maybe_promote_datetimetz_with_datetimetz(
454443
)
455444

456445

457-
@pytest.mark.parametrize(
458-
"fill_value", [None, np.nan, NaT, iNaT], ids=["None", "np.nan", "pd.NaT", "iNaT"]
459-
)
446+
@pytest.mark.parametrize("fill_value", [None, np.nan, NaT, iNaT])
460447
# override parametrization due to to many xfails; see GH 23982 / 25425
461448
@pytest.mark.parametrize("box", [(False, None)])
462449
def test_maybe_promote_datetimetz_with_na(tz_aware_fixture, fill_value, box):
@@ -572,8 +559,8 @@ def test_maybe_promote_any_with_timedelta64(
572559
if is_timedelta64_dtype(dtype):
573560
expected_dtype = dtype
574561
# for timedelta dtypes, scalar values get cast to pd.Timedelta.value
575-
exp_val_for_scalar = pd.Timedelta(fill_value).value
576-
exp_val_for_array = iNaT
562+
exp_val_for_scalar = pd.Timedelta(fill_value).to_timedelta64()
563+
exp_val_for_array = np.timedelta64("NaT", "ns")
577564
else:
578565
expected_dtype = np.dtype(object)
579566
exp_val_for_scalar = fill_value
@@ -714,9 +701,7 @@ def test_maybe_promote_any_with_object(any_numpy_dtype_reduced, object_dtype, bo
714701
)
715702

716703

717-
@pytest.mark.parametrize(
718-
"fill_value", [None, np.nan, NaT, iNaT], ids=["None", "np.nan", "pd.NaT", "iNaT"]
719-
)
704+
@pytest.mark.parametrize("fill_value", [None, np.nan, NaT, iNaT])
720705
# override parametrization due to to many xfails; see GH 23982 / 25425
721706
@pytest.mark.parametrize("box", [(False, None)])
722707
def test_maybe_promote_any_numpy_dtype_with_na(
@@ -764,7 +749,7 @@ def test_maybe_promote_any_numpy_dtype_with_na(
764749
elif is_datetime_or_timedelta_dtype(dtype):
765750
# datetime / timedelta cast all missing values to iNaT
766751
expected_dtype = dtype
767-
exp_val_for_scalar = iNaT
752+
exp_val_for_scalar = dtype.type("NaT", "ns")
768753
elif fill_value is NaT:
769754
# NaT upcasts everything that's not datetime/timedelta to object
770755
expected_dtype = np.dtype(object)
@@ -783,7 +768,7 @@ def test_maybe_promote_any_numpy_dtype_with_na(
783768
# integers cannot hold NaNs; maybe_promote_with_array returns None
784769
exp_val_for_array = None
785770
elif is_datetime_or_timedelta_dtype(expected_dtype):
786-
exp_val_for_array = iNaT
771+
exp_val_for_array = expected_dtype.type("NaT", "ns")
787772
else: # expected_dtype = float / complex / object
788773
exp_val_for_array = np.nan
789774

@@ -817,7 +802,4 @@ def test_maybe_promote_dimensions(any_numpy_dtype_reduced, dim):
817802
result_dtype, result_missing_value = maybe_promote(dtype, fill_array)
818803

819804
assert result_dtype == expected_dtype
820-
# None == None, iNaT == iNaT, but np.nan != np.nan
821-
assert (result_missing_value == expected_missing_value) or (
822-
result_missing_value is np.nan and expected_missing_value is np.nan
823-
)
805+
_assert_match(result_missing_value, expected_missing_value)

0 commit comments

Comments
 (0)