Skip to content

Commit 5ad4e57

Browse files
authored
REF: share .astype code for astype_nansafe + TDA.astype (#38481)
1 parent 8455f57 commit 5ad4e57

File tree

4 files changed

+59
-52
lines changed

4 files changed

+59
-52
lines changed

pandas/core/arrays/timedeltas.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525
from pandas.compat.numpy import function as nv
2626

27+
from pandas.core.dtypes.cast import astype_td64_unit_conversion
2728
from pandas.core.dtypes.common import (
2829
DT64NS_DTYPE,
2930
TD64NS_DTYPE,
@@ -35,7 +36,6 @@
3536
is_scalar,
3637
is_string_dtype,
3738
is_timedelta64_dtype,
38-
is_timedelta64_ns_dtype,
3939
pandas_dtype,
4040
)
4141
from pandas.core.dtypes.dtypes import DatetimeTZDtype
@@ -324,22 +324,14 @@ def astype(self, dtype, copy: bool = True):
324324
# DatetimeLikeArrayMixin super call handles other cases
325325
dtype = pandas_dtype(dtype)
326326

327-
if is_timedelta64_dtype(dtype) and not is_timedelta64_ns_dtype(dtype):
328-
# by pandas convention, converting to non-nano timedelta64
329-
# returns an int64-dtyped array with ints representing multiples
330-
# of the desired timedelta unit. This is essentially division
331-
if self._hasnans:
332-
# avoid double-copying
333-
result = self._data.astype(dtype, copy=False)
334-
return self._maybe_mask_results(
335-
result, fill_value=None, convert="float64"
336-
)
337-
result = self._data.astype(dtype, copy=copy)
338-
return result.astype("i8")
339-
elif is_timedelta64_ns_dtype(dtype):
327+
if is_dtype_equal(dtype, self.dtype):
340328
if copy:
341329
return self.copy()
342330
return self
331+
332+
elif dtype.kind == "m":
333+
return astype_td64_unit_conversion(self._data, dtype, copy=copy)
334+
343335
return dtl.DatetimeLikeArrayMixin.astype(self, dtype, copy=copy)
344336

345337
def __iter__(self):

pandas/core/dtypes/cast.py

+35-12
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838

3939
from pandas.core.dtypes.common import (
4040
DT64NS_DTYPE,
41-
INT64_DTYPE,
4241
POSSIBLY_CAST_DTYPES,
4342
TD64NS_DTYPE,
4443
ensure_int8,
@@ -952,6 +951,39 @@ def coerce_indexer_dtype(indexer, categories):
952951
return ensure_int64(indexer)
953952

954953

954+
def astype_td64_unit_conversion(
955+
values: np.ndarray, dtype: np.dtype, copy: bool
956+
) -> np.ndarray:
957+
"""
958+
By pandas convention, converting to non-nano timedelta64
959+
returns an int64-dtyped array with ints representing multiples
960+
of the desired timedelta unit. This is essentially division.
961+
962+
Parameters
963+
----------
964+
values : np.ndarray[timedelta64[ns]]
965+
dtype : np.dtype
966+
timedelta64 with unit not-necessarily nano
967+
copy : bool
968+
969+
Returns
970+
-------
971+
np.ndarray
972+
"""
973+
if is_dtype_equal(values.dtype, dtype):
974+
if copy:
975+
return values.copy()
976+
return values
977+
978+
# otherwise we are converting to non-nano
979+
result = values.astype(dtype, copy=False) # avoid double-copying
980+
result = result.astype(np.float64)
981+
982+
mask = isna(values)
983+
np.putmask(result, mask, np.nan)
984+
return result
985+
986+
955987
def astype_nansafe(
956988
arr, dtype: DtypeObj, copy: bool = True, skipna: bool = False
957989
) -> ArrayLike:
@@ -1015,17 +1047,8 @@ def astype_nansafe(
10151047
raise ValueError("Cannot convert NaT values to integer")
10161048
return arr.view(dtype)
10171049

1018-
if dtype not in [INT64_DTYPE, TD64NS_DTYPE]:
1019-
1020-
# allow frequency conversions
1021-
# we return a float here!
1022-
if dtype.kind == "m":
1023-
mask = isna(arr)
1024-
result = arr.astype(dtype).astype(np.float64)
1025-
result[mask] = np.nan
1026-
return result
1027-
elif dtype == TD64NS_DTYPE:
1028-
return arr.astype(TD64NS_DTYPE, copy=copy)
1050+
elif dtype.kind == "m":
1051+
return astype_td64_unit_conversion(arr, dtype, copy=copy)
10291052

10301053
raise TypeError(f"cannot astype a timedelta from [{arr.dtype}] to [{dtype}]")
10311054

pandas/core/indexes/timedeltas.py

+3-26
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,14 @@
44
from pandas._libs.tslibs import Timedelta, to_offset
55
from pandas._typing import DtypeObj
66
from pandas.errors import InvalidIndexError
7-
from pandas.util._decorators import doc
8-
9-
from pandas.core.dtypes.common import (
10-
TD64NS_DTYPE,
11-
is_scalar,
12-
is_timedelta64_dtype,
13-
is_timedelta64_ns_dtype,
14-
pandas_dtype,
15-
)
7+
8+
from pandas.core.dtypes.common import TD64NS_DTYPE, is_scalar, is_timedelta64_dtype
169

1710
from pandas.core.arrays import datetimelike as dtl
1811
from pandas.core.arrays.timedeltas import TimedeltaArray
1912
import pandas.core.common as com
2013
from pandas.core.indexes.base import Index, maybe_extract_name
21-
from pandas.core.indexes.datetimelike import (
22-
DatetimeIndexOpsMixin,
23-
DatetimeTimedeltaMixin,
24-
)
14+
from pandas.core.indexes.datetimelike import DatetimeTimedeltaMixin
2515
from pandas.core.indexes.extension import inherit_names
2616

2717

@@ -159,19 +149,6 @@ def __new__(
159149

160150
# -------------------------------------------------------------------
161151

162-
@doc(Index.astype)
163-
def astype(self, dtype, copy: bool = True):
164-
dtype = pandas_dtype(dtype)
165-
if is_timedelta64_dtype(dtype) and not is_timedelta64_ns_dtype(dtype):
166-
# Have to repeat the check for 'timedelta64' (not ns) dtype
167-
# so that we can return a numeric index, since pandas will return
168-
# a TimedeltaIndex when dtype='timedelta'
169-
result = self._data.astype(dtype, copy=copy)
170-
if self.hasnans:
171-
return Index(result, name=self.name)
172-
return Index(result.astype("i8"), name=self.name)
173-
return DatetimeIndexOpsMixin.astype(self, dtype, copy=copy)
174-
175152
def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
176153
"""
177154
Can we compare values of the given dtype to our own?

pandas/tests/indexes/timedeltas/test_timedelta.py

+15
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,21 @@ def test_fields(self):
189189
rng.name = "name"
190190
assert rng.days.name == "name"
191191

192+
def test_freq_conversion_always_floating(self):
193+
# even if we have no NaTs, we get back float64; this matches TDA and Series
194+
tdi = timedelta_range("1 Day", periods=30)
195+
196+
res = tdi.astype("m8[s]")
197+
expected = Index((tdi.view("i8") / 10 ** 9).astype(np.float64))
198+
tm.assert_index_equal(res, expected)
199+
200+
# check this matches Series and TimedeltaArray
201+
res = tdi._data.astype("m8[s]")
202+
tm.assert_numpy_array_equal(res, expected._values)
203+
204+
res = tdi.to_series().astype("m8[s]")
205+
tm.assert_numpy_array_equal(res._values, expected._values)
206+
192207
def test_freq_conversion(self):
193208

194209
# doc example

0 commit comments

Comments
 (0)