Skip to content

Commit 8429441

Browse files
authored
REF: stronger typing in _box_func (#46917)
1 parent a4338a5 commit 8429441

File tree

5 files changed

+20
-18
lines changed

5 files changed

+20
-18
lines changed

pandas/core/arrays/datetimes.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -538,13 +538,10 @@ def _check_compatible_with(self, other, setitem: bool = False):
538538
# -----------------------------------------------------------------
539539
# Descriptive Properties
540540

541-
def _box_func(self, x) -> Timestamp | NaTType:
542-
if isinstance(x, np.datetime64):
543-
# GH#42228
544-
# Argument 1 to "signedinteger" has incompatible type "datetime64";
545-
# expected "Union[SupportsInt, Union[str, bytes], SupportsIndex]"
546-
x = np.int64(x) # type: ignore[arg-type]
547-
ts = Timestamp(x, tz=self.tz)
541+
def _box_func(self, x: np.datetime64) -> Timestamp | NaTType:
542+
# GH#42228
543+
value = x.view("i8")
544+
ts = Timestamp(value, tz=self.tz)
548545
# Non-overlapping identity check (left operand type: "Timestamp",
549546
# right operand type: "NaTType")
550547
if ts is not NaT: # type: ignore[comparison-overlap]

pandas/core/arrays/timedeltas.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ class TimedeltaArray(dtl.TimelikeOps):
154154
# Note: ndim must be defined to ensure NaT.__richcmp__(TimedeltaArray)
155155
# operates pointwise.
156156

157-
def _box_func(self, x) -> Timedelta | NaTType:
157+
def _box_func(self, x: np.timedelta64) -> Timedelta | NaTType:
158158
return Timedelta(x, unit="ns")
159159

160160
@property

pandas/core/nanops.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from pandas._libs import (
1717
NaT,
1818
NaTType,
19-
Timedelta,
2019
iNaT,
2120
lib,
2221
)
@@ -367,19 +366,23 @@ def _wrap_results(result, dtype: np.dtype, fill_value=None):
367366
result = np.datetime64("NaT", "ns")
368367
else:
369368
result = np.int64(result).view("datetime64[ns]")
369+
# retain original unit
370+
result = result.astype(dtype, copy=False)
370371
else:
371372
# If we have float dtype, taking a view will give the wrong result
372373
result = result.astype(dtype)
373374
elif is_timedelta64_dtype(dtype):
374375
if not isinstance(result, np.ndarray):
375-
if result == fill_value:
376-
result = np.nan
376+
if result == fill_value or np.isnan(result):
377+
result = np.timedelta64("NaT").astype(dtype)
377378

378-
# raise if we have a timedelta64[ns] which is too large
379-
if np.fabs(result) > lib.i8max:
379+
elif np.fabs(result) > lib.i8max:
380+
# raise if we have a timedelta64[ns] which is too large
380381
raise ValueError("overflow in timedelta operation")
382+
else:
383+
# return a timedelta64 with the original unit
384+
result = np.int64(result).astype(dtype, copy=False)
381385

382-
result = Timedelta(result, unit="ns")
383386
else:
384387
result = result.astype("m8[ns]").view(dtype)
385388

@@ -641,7 +644,7 @@ def _mask_datetimelike_result(
641644
result[axis_mask] = iNaT # type: ignore[index]
642645
else:
643646
if mask.any():
644-
return NaT
647+
return np.int64(iNaT).view(orig_values.dtype)
645648
return result
646649

647650

pandas/tests/arrays/timedeltas/test_reductions.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_std(self, add):
147147

148148
if getattr(arr, "tz", None) is None:
149149
result = nanops.nanstd(np.asarray(arr), skipna=True)
150-
assert isinstance(result, Timedelta)
150+
assert isinstance(result, np.timedelta64)
151151
assert result == expected
152152

153153
result = arr.std(skipna=False)
@@ -158,7 +158,8 @@ def test_std(self, add):
158158

159159
if getattr(arr, "tz", None) is None:
160160
result = nanops.nanstd(np.asarray(arr), skipna=False)
161-
assert result is pd.NaT
161+
assert isinstance(result, np.timedelta64)
162+
assert np.isnat(result)
162163

163164
def test_median(self):
164165
tdi = pd.TimedeltaIndex(["0H", "3H", "NaT", "5H06m", "0H", "2H"])

pandas/tests/test_nanops.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1020,7 +1020,8 @@ def test_nanmean_skipna_false(self, dtype):
10201020
arr[-1, -1] = "NaT"
10211021

10221022
result = nanops.nanmean(arr, skipna=False)
1023-
assert result is pd.NaT
1023+
assert np.isnat(result)
1024+
assert result.dtype == dtype
10241025

10251026
result = nanops.nanmean(arr, axis=0, skipna=False)
10261027
expected = np.array([4, 5, "NaT"], dtype=arr.dtype)

0 commit comments

Comments
 (0)