From acbc1312412db400959c721fce8ee64ebcb704dc Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 27 Feb 2024 15:01:02 -0800 Subject: [PATCH] REF: simplify pytables _set_tz --- pandas/io/pytables.py | 71 ++++++++++++++----------------------------- 1 file changed, 22 insertions(+), 49 deletions(-) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index c7ee1cac2e14a..c835a7365d158 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -21,7 +21,6 @@ Final, Literal, cast, - overload, ) import warnings @@ -85,6 +84,7 @@ DatetimeArray, PeriodArray, ) +from pandas.core.arrays.datetimes import tz_to_dtype import pandas.core.common as com from pandas.core.computation.pytables import ( PyTablesExpr, @@ -2170,7 +2170,12 @@ def convert( if "freq" in kwargs: kwargs["freq"] = None new_pd_index = factory(values, **kwargs) - final_pd_index = _set_tz(new_pd_index, self.tz) + + final_pd_index: Index + if self.tz is not None and isinstance(new_pd_index, DatetimeIndex): + final_pd_index = new_pd_index.tz_localize("UTC").tz_convert(self.tz) + else: + final_pd_index = new_pd_index return final_pd_index, final_pd_index def take_data(self): @@ -2567,7 +2572,7 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): # reverse converts if dtype.startswith("datetime64"): # recreate with tz if indicated - converted = _set_tz(converted, tz, coerce=True) + converted = _set_tz(converted, tz) elif dtype == "timedelta64": converted = np.asarray(converted, dtype="m8[ns]") @@ -2948,7 +2953,7 @@ def read_array(self, key: str, start: int | None = None, stop: int | None = None if dtype and dtype.startswith("datetime64"): # reconstruct a timezone if indicated tz = getattr(attrs, "tz", None) - ret = _set_tz(ret, tz, coerce=True) + ret = _set_tz(ret, tz) elif dtype == "timedelta64": ret = np.asarray(ret, dtype="m8[ns]") @@ -4298,7 +4303,8 @@ def read_column( encoding=self.encoding, errors=self.errors, ) - return Series(_set_tz(col_values[1], a.tz), name=column, copy=False) + cvs = col_values[1] + return Series(cvs, name=column, copy=False) raise KeyError(f"column [{column}] not found in the table") @@ -4637,7 +4643,7 @@ def read( if values.ndim == 1 and isinstance(values, np.ndarray): values = values.reshape((1, values.shape[0])) - if isinstance(values, np.ndarray): + if isinstance(values, (np.ndarray, DatetimeArray)): df = DataFrame(values.T, columns=cols_, index=index_, copy=False) elif isinstance(values, Index): df = DataFrame(values, columns=cols_, index=index_) @@ -4873,54 +4879,21 @@ def _get_tz(tz: tzinfo) -> str | tzinfo: return zone -@overload -def _set_tz( - values: np.ndarray | Index, tz: str | tzinfo, coerce: bool = False -) -> DatetimeIndex: - ... - - -@overload -def _set_tz(values: np.ndarray | Index, tz: None, coerce: bool = False) -> np.ndarray: - ... - - -def _set_tz( - values: np.ndarray | Index, tz: str | tzinfo | None, coerce: bool = False -) -> np.ndarray | DatetimeIndex: +def _set_tz(values: npt.NDArray[np.int64], tz: str | tzinfo | None) -> DatetimeArray: """ - coerce the values to a DatetimeIndex if tz is set - preserve the input shape if possible + Coerce the values to a DatetimeArray with appropriate tz. Parameters ---------- - values : ndarray or Index - tz : str or tzinfo - coerce : if we do not have a passed timezone, coerce to M8[ns] ndarray + values : ndarray[int64] + tz : str, tzinfo, or None """ - if isinstance(values, DatetimeIndex): - # If values is tzaware, the tz gets dropped in the values.ravel() - # call below (which returns an ndarray). So we are only non-lossy - # if `tz` matches `values.tz`. - assert values.tz is None or values.tz == tz - if values.tz is not None: - return values - - if tz is not None: - if isinstance(values, DatetimeIndex): - name = values.name - else: - name = None - values = values.ravel() - - values = DatetimeIndex(values, name=name) - values = values.tz_localize("UTC").tz_convert(tz) - elif coerce: - values = np.asarray(values, dtype="M8[ns]") - - # error: Incompatible return value type (got "Union[ndarray, Index]", - # expected "Union[ndarray, DatetimeIndex]") - return values # type: ignore[return-value] + assert values.dtype == "i8", values.dtype + # Argument "tz" to "tz_to_dtype" has incompatible type "str | tzinfo | None"; + # expected "tzinfo" + dtype = tz_to_dtype(tz=tz, unit="ns") # type: ignore[arg-type] + dta = DatetimeArray._from_sequence(values, dtype=dtype) + return dta def _convert_index(name: str, index: Index, encoding: str, errors: str) -> IndexCol: