Skip to content

REF: simplify pytables _set_tz #57654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 22 additions & 49 deletions pandas/io/pytables.py
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
Final,
Literal,
cast,
overload,
)
import warnings

@@ -85,6 +84,7 @@
DatetimeArray,
PeriodArray,
)
from pandas.core.arrays.datetimes import tz_to_dtype
import pandas.core.common as com
from pandas.core.computation.pytables import (
PyTablesExpr,
@@ -2170,7 +2170,12 @@ def convert(
if "freq" in kwargs:
kwargs["freq"] = None
new_pd_index = factory(values, **kwargs)
final_pd_index = _set_tz(new_pd_index, self.tz)

final_pd_index: Index
if self.tz is not None and isinstance(new_pd_index, DatetimeIndex):
final_pd_index = new_pd_index.tz_localize("UTC").tz_convert(self.tz)
else:
final_pd_index = new_pd_index
return final_pd_index, final_pd_index

def take_data(self):
@@ -2567,7 +2572,7 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str):
# reverse converts
if dtype.startswith("datetime64"):
# recreate with tz if indicated
converted = _set_tz(converted, tz, coerce=True)
converted = _set_tz(converted, tz)

elif dtype == "timedelta64":
converted = np.asarray(converted, dtype="m8[ns]")
@@ -2948,7 +2953,7 @@ def read_array(self, key: str, start: int | None = None, stop: int | None = None
if dtype and dtype.startswith("datetime64"):
# reconstruct a timezone if indicated
tz = getattr(attrs, "tz", None)
ret = _set_tz(ret, tz, coerce=True)
ret = _set_tz(ret, tz)

elif dtype == "timedelta64":
ret = np.asarray(ret, dtype="m8[ns]")
@@ -4298,7 +4303,8 @@ def read_column(
encoding=self.encoding,
errors=self.errors,
)
return Series(_set_tz(col_values[1], a.tz), name=column, copy=False)
cvs = col_values[1]
return Series(cvs, name=column, copy=False)

raise KeyError(f"column [{column}] not found in the table")

@@ -4637,7 +4643,7 @@ def read(
if values.ndim == 1 and isinstance(values, np.ndarray):
values = values.reshape((1, values.shape[0]))

if isinstance(values, np.ndarray):
if isinstance(values, (np.ndarray, DatetimeArray)):
df = DataFrame(values.T, columns=cols_, index=index_, copy=False)
elif isinstance(values, Index):
df = DataFrame(values, columns=cols_, index=index_)
@@ -4873,54 +4879,21 @@ def _get_tz(tz: tzinfo) -> str | tzinfo:
return zone


@overload
def _set_tz(
values: np.ndarray | Index, tz: str | tzinfo, coerce: bool = False
) -> DatetimeIndex:
...


@overload
def _set_tz(values: np.ndarray | Index, tz: None, coerce: bool = False) -> np.ndarray:
...


def _set_tz(
values: np.ndarray | Index, tz: str | tzinfo | None, coerce: bool = False
) -> np.ndarray | DatetimeIndex:
def _set_tz(values: npt.NDArray[np.int64], tz: str | tzinfo | None) -> DatetimeArray:
"""
coerce the values to a DatetimeIndex if tz is set
preserve the input shape if possible
Coerce the values to a DatetimeArray with appropriate tz.

Parameters
----------
values : ndarray or Index
tz : str or tzinfo
coerce : if we do not have a passed timezone, coerce to M8[ns] ndarray
values : ndarray[int64]
tz : str, tzinfo, or None
"""
if isinstance(values, DatetimeIndex):
# If values is tzaware, the tz gets dropped in the values.ravel()
# call below (which returns an ndarray). So we are only non-lossy
# if `tz` matches `values.tz`.
assert values.tz is None or values.tz == tz
if values.tz is not None:
return values

if tz is not None:
if isinstance(values, DatetimeIndex):
name = values.name
else:
name = None
values = values.ravel()

values = DatetimeIndex(values, name=name)
values = values.tz_localize("UTC").tz_convert(tz)
elif coerce:
values = np.asarray(values, dtype="M8[ns]")

# error: Incompatible return value type (got "Union[ndarray, Index]",
# expected "Union[ndarray, DatetimeIndex]")
return values # type: ignore[return-value]
assert values.dtype == "i8", values.dtype
# Argument "tz" to "tz_to_dtype" has incompatible type "str | tzinfo | None";
# expected "tzinfo"
dtype = tz_to_dtype(tz=tz, unit="ns") # type: ignore[arg-type]
dta = DatetimeArray._from_sequence(values, dtype=dtype)
return dta


def _convert_index(name: str, index: Index, encoding: str, errors: str) -> IndexCol:
Expand Down