From 17a03556edd5041ec138eda8dcb062b98aa576d1 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 23 Dec 2020 11:06:02 -0800 Subject: [PATCH 1/2] REF: collect dt64<->dt64tz astype in dtypes.cast --- pandas/core/arrays/datetimes.py | 22 +++------------ pandas/core/dtypes/cast.py | 48 ++++++++++++++++++++++++++++++++- pandas/core/internals/blocks.py | 10 ++----- 3 files changed, 53 insertions(+), 27 deletions(-) diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py index aa1d8f6254e2c..2b2be214428d2 100644 --- a/pandas/core/arrays/datetimes.py +++ b/pandas/core/arrays/datetimes.py @@ -24,6 +24,7 @@ ) from pandas.errors import PerformanceWarning +from pandas.core.dtypes.cast import astype_dt64_to_dt64tz from pandas.core.dtypes.common import ( DT64NS_DTYPE, INT64_DTYPE, @@ -31,6 +32,7 @@ is_categorical_dtype, is_datetime64_any_dtype, is_datetime64_dtype, + is_datetime64_ns_dtype, is_datetime64tz_dtype, is_dtype_equal, is_extension_array_dtype, @@ -591,24 +593,8 @@ def astype(self, dtype, copy=True): return self.copy() return self - elif is_datetime64tz_dtype(dtype) and self.tz is None: - # FIXME: GH#33401 this does not match Series behavior - return self.tz_localize(dtype.tz) - - elif is_datetime64tz_dtype(dtype): - # GH#18951: datetime64_ns dtype but not equal means different tz - result = self.tz_convert(dtype.tz) - if copy: - result = result.copy() - return result - - elif dtype == "M8[ns]": - # we must have self.tz is None, otherwise we would have gone through - # the is_dtype_equal branch above. - result = self.tz_convert("UTC").tz_localize(None) - if copy: - result = result.copy() - return result + elif is_datetime64_ns_dtype(dtype): + return astype_dt64_to_dt64tz(self, dtype, copy, via_utc=False) elif is_period_dtype(dtype): return self.to_period(freq=dtype.freq) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index d1c16de05ce55..26a1ad7a9d8b2 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -85,7 +85,7 @@ if TYPE_CHECKING: from pandas import Series - from pandas.core.arrays import ExtensionArray + from pandas.core.arrays import DatetimeArray, ExtensionArray _int8_max = np.iinfo(np.int8).max _int16_max = np.iinfo(np.int16).max @@ -912,6 +912,52 @@ def coerce_indexer_dtype(indexer, categories): return ensure_int64(indexer) +def astype_dt64_to_dt64tz( + values: ArrayLike, dtype: DtypeObj, copy: bool, via_utc: bool = False +) -> "DatetimeArray": + # GH#33401 we have inconsistent behaviors between + # Datetimeindex[naive].astype(tzaware) + # Series[dt64].astype(tzaware) + # This collects them in one place to prevent further fragmentation. + + from pandas.core.construction import ensure_wrapped_if_datetimelike + + values = ensure_wrapped_if_datetimelike(values) + aware = isinstance(dtype, DatetimeTZDtype) + + if via_utc: + # Series.astype behavior + assert values.tz is None and aware # caller is responsible for checking this + + if copy: + # this should be the only copy + values = values.copy() + # FIXME: GH#33401 this doesn't match DatetimeArray.astype, which + # goes through the `not via_utc` path + return values.tz_localize("UTC").tz_convert(dtype.tz) + + else: + # DatetimeArray/DatetimeIndex.astype behavior + + if values.tz is None and aware: + return values.tz_localize(dtype.tz) + + elif aware: + # GH#18951: datetime64_tz dtype but not equal means different tz + result = values.tz_convert(dtype.tz) + if copy: + result = result.copy() + return result + + elif values.tz is not None and not aware: + result = values.tz_convert("UTC").tz_localize(None) + if copy: + result = result.copy() + return result + + raise NotImplementedError("dtype_equal case should be handled elsewhere") + + def astype_td64_unit_conversion( values: np.ndarray, dtype: np.dtype, copy: bool ) -> np.ndarray: diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 8752224356f61..27acd720b6d71 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -22,6 +22,7 @@ from pandas.util._validators import validate_bool_kwarg from pandas.core.dtypes.cast import ( + astype_dt64_to_dt64tz, astype_nansafe, convert_scalar_for_putitemlike, find_common_type, @@ -649,14 +650,7 @@ def _astype(self, dtype: DtypeObj, copy: bool) -> ArrayLike: values = self.values if is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype): - # if we are passed a datetime64[ns, tz] - if copy: - # this should be the only copy - values = values.copy() - # i.e. values.tz_localize("UTC").tz_convert(dtype.tz) - # FIXME: GH#33401 this doesn't match DatetimeArray.astype, which - # would be self.array_values().tz_localize(dtype.tz) - return DatetimeArray._simple_new(values.view("i8"), dtype=dtype) + return astype_dt64_to_dt64tz(values, dtype, copy, via_utc=True) if is_dtype_equal(values.dtype, dtype): if copy: From a9bd4b6f7c97c662607423b52cf45e9e5ab46397 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 23 Dec 2020 11:20:51 -0800 Subject: [PATCH 2/2] mypy fixup --- pandas/core/dtypes/cast.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 26a1ad7a9d8b2..f78913ae7c807 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -16,6 +16,7 @@ Tuple, Type, Union, + cast, ) import numpy as np @@ -923,11 +924,13 @@ def astype_dt64_to_dt64tz( from pandas.core.construction import ensure_wrapped_if_datetimelike values = ensure_wrapped_if_datetimelike(values) + values = cast("DatetimeArray", values) aware = isinstance(dtype, DatetimeTZDtype) if via_utc: # Series.astype behavior assert values.tz is None and aware # caller is responsible for checking this + dtype = cast(DatetimeTZDtype, dtype) if copy: # this should be the only copy @@ -940,10 +943,12 @@ def astype_dt64_to_dt64tz( # DatetimeArray/DatetimeIndex.astype behavior if values.tz is None and aware: + dtype = cast(DatetimeTZDtype, dtype) return values.tz_localize(dtype.tz) elif aware: # GH#18951: datetime64_tz dtype but not equal means different tz + dtype = cast(DatetimeTZDtype, dtype) result = values.tz_convert(dtype.tz) if copy: result = result.copy()