Skip to content

Commit 55dc324

Browse files
API: .astype td64->td64 give requested dtype (pandas-dev#48963)
* API: .astype td64->td64 give requested dtype * fix missing import * Update doc/source/whatsnew/v1.6.0.rst Co-authored-by: Matthew Roeschke <[email protected]> Co-authored-by: Matthew Roeschke <[email protected]>
1 parent ac05d29 commit 55dc324

File tree

6 files changed

+53
-9
lines changed

6 files changed

+53
-9
lines changed

doc/source/whatsnew/v1.6.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ Other API changes
119119
- Passing ``nanoseconds`` greater than 999 or less than 0 in :class:`Timestamp` now raises a ``ValueError`` (:issue:`48538`, :issue:`48255`)
120120
- :func:`read_csv`: specifying an incorrect number of columns with ``index_col`` of now raises ``ParserError`` instead of ``IndexError`` when using the c parser.
121121
- :meth:`DataFrame.astype`, :meth:`Series.astype`, and :meth:`DatetimeIndex.astype` casting datetime64 data to any of "datetime64[s]", "datetime64[ms]", "datetime64[us]" will return an object with the given resolution instead of coercing back to "datetime64[ns]" (:issue:`48928`)
122+
- :meth:`DataFrame.astype`, :meth:`Series.astype`, and :meth:`DatetimeIndex.astype` casting timedelta64 data to any of "timedelta64[s]", "timedelta64[ms]", "timedelta64[us]" will return an object with the given resolution instead of coercing to "float64" dtype (:issue:`48963`)
122123
-
123124

124125
.. ---------------------------------------------------------------------------

pandas/core/arrays/timedeltas.py

+14
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
Tick,
2121
Timedelta,
2222
astype_overflowsafe,
23+
get_unit_from_dtype,
2324
iNaT,
25+
is_supported_unit,
2426
periods_per_second,
2527
to_offset,
2628
)
@@ -308,6 +310,18 @@ def astype(self, dtype, copy: bool = True):
308310
dtype = pandas_dtype(dtype)
309311

310312
if dtype.kind == "m":
313+
if dtype == self.dtype:
314+
if copy:
315+
return self.copy()
316+
return self
317+
318+
if is_supported_unit(get_unit_from_dtype(dtype)):
319+
# unit conversion e.g. timedelta64[s]
320+
res_values = astype_overflowsafe(self._ndarray, dtype, copy=False)
321+
return type(self)._simple_new(
322+
res_values, dtype=res_values.dtype, freq=self.freq
323+
)
324+
311325
return astype_td64_unit_conversion(self._ndarray, dtype, copy=copy)
312326

313327
return dtl.DatetimeLikeArrayMixin.astype(self, dtype, copy=copy)

pandas/core/dtypes/astype.py

+8
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,14 @@ def astype_nansafe(
136136
return arr.view(dtype)
137137

138138
elif dtype.kind == "m":
139+
# TODO(2.0): change to use the same logic as TDA.astype, i.e.
140+
# giving the requested dtype for supported units (s, ms, us, ns)
141+
# and doing the old convert-to-float behavior otherwise.
142+
if is_supported_unit(get_unit_from_dtype(arr.dtype)):
143+
from pandas.core.construction import ensure_wrapped_if_datetimelike
144+
145+
arr = ensure_wrapped_if_datetimelike(arr)
146+
return arr.astype(dtype, copy=copy)
139147
return astype_td64_unit_conversion(arr, dtype, copy=copy)
140148

141149
raise TypeError(f"cannot astype a timedelta from [{arr.dtype}] to [{dtype}]")

pandas/tests/frame/methods/test_astype.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -480,12 +480,19 @@ def test_astype_to_timedelta_unit_ns(self, unit):
480480
@pytest.mark.parametrize("unit", ["us", "ms", "s", "h", "m", "D"])
481481
def test_astype_to_timedelta_unit(self, unit):
482482
# coerce to float
483-
# GH#19223
483+
# GH#19223 until 2.0 used to coerce to float
484484
dtype = f"m8[{unit}]"
485485
arr = np.array([[1, 2, 3]], dtype=dtype)
486486
df = DataFrame(arr)
487487
result = df.astype(dtype)
488-
expected = DataFrame(df.values.astype(dtype).astype(float))
488+
489+
if unit in ["m", "h", "D"]:
490+
# We don't support these, so we use the old logic to convert to float
491+
expected = DataFrame(df.values.astype(dtype).astype(float))
492+
else:
493+
tda = pd.core.arrays.TimedeltaArray._simple_new(arr, dtype=arr.dtype)
494+
expected = DataFrame(tda)
495+
assert (expected.dtypes == dtype).all()
489496

490497
tm.assert_frame_equal(result, expected)
491498

pandas/tests/indexes/timedeltas/test_timedelta.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
timedelta_range,
1414
)
1515
import pandas._testing as tm
16+
from pandas.core.arrays import TimedeltaArray
1617
from pandas.core.indexes.api import Int64Index
1718
from pandas.tests.indexes.datetimelike import DatetimeLike
1819

@@ -101,19 +102,26 @@ def test_fields(self):
101102
assert rng.days.name == "name"
102103

103104
def test_freq_conversion_always_floating(self):
104-
# even if we have no NaTs, we get back float64; this matches TDA and Series
105+
# pre-2.0 td64 astype converted to float64. now for supported units
106+
# (s, ms, us, ns) this converts to the requested dtype.
107+
# This matches TDA and Series
105108
tdi = timedelta_range("1 Day", periods=30)
106109

107110
res = tdi.astype("m8[s]")
108-
expected = Index((tdi.view("i8") / 10**9).astype(np.float64))
111+
exp_values = np.asarray(tdi).astype("m8[s]")
112+
exp_tda = TimedeltaArray._simple_new(
113+
exp_values, dtype=exp_values.dtype, freq=tdi.freq
114+
)
115+
expected = Index(exp_tda)
116+
assert expected.dtype == "m8[s]"
109117
tm.assert_index_equal(res, expected)
110118

111119
# check this matches Series and TimedeltaArray
112120
res = tdi._data.astype("m8[s]")
113-
tm.assert_numpy_array_equal(res, expected._values)
121+
tm.assert_equal(res, expected._values)
114122

115123
res = tdi.to_series().astype("m8[s]")
116-
tm.assert_numpy_array_equal(res._values, expected._values)
124+
tm.assert_equal(res._values, expected._values._with_freq(None))
117125

118126
def test_freq_conversion(self, index_or_series):
119127

@@ -131,6 +139,8 @@ def test_freq_conversion(self, index_or_series):
131139
)
132140
tm.assert_equal(result, expected)
133141

142+
# We don't support "D" reso, so we use the pre-2.0 behavior
143+
# casting to float64
134144
result = td.astype("timedelta64[D]")
135145
expected = index_or_series([31, 31, 31, np.nan])
136146
tm.assert_equal(result, expected)
@@ -141,5 +151,9 @@ def test_freq_conversion(self, index_or_series):
141151
)
142152
tm.assert_equal(result, expected)
143153

154+
exp_values = np.asarray(td).astype("m8[s]")
155+
exp_tda = TimedeltaArray._simple_new(exp_values, dtype=exp_values.dtype)
156+
expected = index_or_series(exp_tda)
157+
assert expected.dtype == "m8[s]"
144158
result = td.astype("timedelta64[s]")
145159
tm.assert_equal(result, expected)

pandas/tests/indexing/test_loc.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -649,14 +649,14 @@ def test_loc_setitem_consistency_slice_column_len(self):
649649
)
650650

651651
with tm.assert_produces_warning(None, match=msg):
652-
# timedelta64[s] -> float64, so this cannot be done inplace, so
652+
# timedelta64[m] -> float64, so this cannot be done inplace, so
653653
# no warning
654654
df.loc[:, ("Respondent", "Duration")] = df.loc[
655655
:, ("Respondent", "Duration")
656-
].astype("timedelta64[s]")
656+
].astype("timedelta64[m]")
657657

658658
expected = Series(
659-
[1380, 720, 840, 2160.0], index=df.index, name=("Respondent", "Duration")
659+
[23.0, 12.0, 14.0, 36.0], index=df.index, name=("Respondent", "Duration")
660660
)
661661
tm.assert_series_equal(df[("Respondent", "Duration")], expected)
662662

0 commit comments

Comments
 (0)