Skip to content

Commit 0805043

Browse files
authored
REF: collect dt64<->dt64tz astype in dtypes.cast (#38662)
1 parent 573caff commit 0805043

File tree

3 files changed

+58
-27
lines changed

3 files changed

+58
-27
lines changed

pandas/core/arrays/datetimes.py

+4-18
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@
2424
)
2525
from pandas.errors import PerformanceWarning
2626

27+
from pandas.core.dtypes.cast import astype_dt64_to_dt64tz
2728
from pandas.core.dtypes.common import (
2829
DT64NS_DTYPE,
2930
INT64_DTYPE,
3031
is_bool_dtype,
3132
is_categorical_dtype,
3233
is_datetime64_any_dtype,
3334
is_datetime64_dtype,
35+
is_datetime64_ns_dtype,
3436
is_datetime64tz_dtype,
3537
is_dtype_equal,
3638
is_extension_array_dtype,
@@ -591,24 +593,8 @@ def astype(self, dtype, copy=True):
591593
return self.copy()
592594
return self
593595

594-
elif is_datetime64tz_dtype(dtype) and self.tz is None:
595-
# FIXME: GH#33401 this does not match Series behavior
596-
return self.tz_localize(dtype.tz)
597-
598-
elif is_datetime64tz_dtype(dtype):
599-
# GH#18951: datetime64_ns dtype but not equal means different tz
600-
result = self.tz_convert(dtype.tz)
601-
if copy:
602-
result = result.copy()
603-
return result
604-
605-
elif dtype == "M8[ns]":
606-
# we must have self.tz is None, otherwise we would have gone through
607-
# the is_dtype_equal branch above.
608-
result = self.tz_convert("UTC").tz_localize(None)
609-
if copy:
610-
result = result.copy()
611-
return result
596+
elif is_datetime64_ns_dtype(dtype):
597+
return astype_dt64_to_dt64tz(self, dtype, copy, via_utc=False)
612598

613599
elif is_period_dtype(dtype):
614600
return self.to_period(freq=dtype.freq)

pandas/core/dtypes/cast.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Tuple,
1616
Type,
1717
Union,
18+
cast,
1819
)
1920
import warnings
2021

@@ -85,7 +86,7 @@
8586

8687
if TYPE_CHECKING:
8788
from pandas import Series
88-
from pandas.core.arrays import ExtensionArray
89+
from pandas.core.arrays import DatetimeArray, ExtensionArray
8990

9091
_int8_max = np.iinfo(np.int8).max
9192
_int16_max = np.iinfo(np.int16).max
@@ -920,6 +921,56 @@ def coerce_indexer_dtype(indexer, categories):
920921
return ensure_int64(indexer)
921922

922923

924+
def astype_dt64_to_dt64tz(
    values: ArrayLike, dtype: DtypeObj, copy: bool, via_utc: bool = False
) -> "DatetimeArray":
    """
    Cast datetime64 data between tz-naive and tz-aware dtypes.

    GH#33401: ``DatetimeIndex[naive].astype(tzaware)`` and
    ``Series[dt64].astype(tzaware)`` behave inconsistently.  Both code paths
    live here so the divergence stays in a single place.

    Parameters
    ----------
    values : ArrayLike
        Datetime data; it is passed through ``ensure_wrapped_if_datetimelike``
        and from then on treated as a DatetimeArray.
    dtype : DtypeObj
        Target dtype.  It counts as tz-aware when it is a ``DatetimeTZDtype``.
    copy : bool
        Whether to copy.  Not consulted on the naive -> aware branch of the
        ``via_utc=False`` path, which returns the ``tz_localize`` result
        directly.
    via_utc : bool, default False
        If True, use the Series.astype behavior: the naive values are
        localized to UTC and then converted to ``dtype.tz``.  The caller must
        guarantee that ``values`` is tz-naive and ``dtype`` is tz-aware.
        If False, use the DatetimeArray/DatetimeIndex.astype behavior.

    Returns
    -------
    DatetimeArray

    Raises
    ------
    NotImplementedError
        With ``via_utc=False`` when ``values`` is tz-naive and ``dtype`` is
        not tz-aware; that case is expected to be handled by the caller.
    """
    from pandas.core.construction import ensure_wrapped_if_datetimelike

    dta = cast("DatetimeArray", ensure_wrapped_if_datetimelike(values))
    target_is_aware = isinstance(dtype, DatetimeTZDtype)

    if via_utc:
        # Series.astype behavior
        # caller is responsible for checking this
        assert dta.tz is None and target_is_aware
        tz_dtype = cast(DatetimeTZDtype, dtype)

        if copy:
            # this should be the only copy
            dta = dta.copy()

        # FIXME: GH#33401 this doesn't match DatetimeArray.astype, which
        #  goes through the `not via_utc` path
        as_utc = dta.tz_localize("UTC")
        return as_utc.tz_convert(tz_dtype.tz)

    # DatetimeArray/DatetimeIndex.astype behavior
    source_is_naive = dta.tz is None

    if target_is_aware:
        tz_dtype = cast(DatetimeTZDtype, dtype)
        if source_is_naive:
            # naive -> aware: the values are taken to be in the target tz
            return dta.tz_localize(tz_dtype.tz)

        # GH#18951: datetime64_tz dtype but not equal means different tz
        converted = dta.tz_convert(tz_dtype.tz)
        return converted.copy() if copy else converted

    if not source_is_naive:
        # aware -> naive: convert to UTC, then drop the timezone
        stripped = dta.tz_convert("UTC").tz_localize(None)
        return stripped.copy() if copy else stripped

    raise NotImplementedError("dtype_equal case should be handled elsewhere")
972+
973+
923974
def astype_td64_unit_conversion(
924975
values: np.ndarray, dtype: np.dtype, copy: bool
925976
) -> np.ndarray:

pandas/core/internals/blocks.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pandas.util._validators import validate_bool_kwarg
2323

2424
from pandas.core.dtypes.cast import (
25+
astype_dt64_to_dt64tz,
2526
astype_nansafe,
2627
convert_scalar_for_putitemlike,
2728
find_common_type,
@@ -649,14 +650,7 @@ def _astype(self, dtype: DtypeObj, copy: bool) -> ArrayLike:
649650
values = self.values
650651

651652
if is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype):
652-
# if we are passed a datetime64[ns, tz]
653-
if copy:
654-
# this should be the only copy
655-
values = values.copy()
656-
# i.e. values.tz_localize("UTC").tz_convert(dtype.tz)
657-
# FIXME: GH#33401 this doesn't match DatetimeArray.astype, which
658-
# would be self.array_values().tz_localize(dtype.tz)
659-
return DatetimeArray._simple_new(values.view("i8"), dtype=dtype)
653+
return astype_dt64_to_dt64tz(values, dtype, copy, via_utc=True)
660654

661655
if is_dtype_equal(values.dtype, dtype):
662656
if copy:

0 commit comments

Comments
 (0)