Skip to content

Commit aabc35a

Browse files
authored
REF: simplify pytables _set_tz (#57654)
1 parent 015f8d3 commit aabc35a

File tree

1 file changed

+22
-49
lines changed

1 file changed

+22
-49
lines changed

pandas/io/pytables.py

+22-49
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
Final,
2222
Literal,
2323
cast,
24-
overload,
2524
)
2625
import warnings
2726

@@ -85,6 +84,7 @@
8584
DatetimeArray,
8685
PeriodArray,
8786
)
87+
from pandas.core.arrays.datetimes import tz_to_dtype
8888
import pandas.core.common as com
8989
from pandas.core.computation.pytables import (
9090
PyTablesExpr,
@@ -2170,7 +2170,12 @@ def convert(
21702170
if "freq" in kwargs:
21712171
kwargs["freq"] = None
21722172
new_pd_index = factory(values, **kwargs)
2173-
final_pd_index = _set_tz(new_pd_index, self.tz)
2173+
2174+
final_pd_index: Index
2175+
if self.tz is not None and isinstance(new_pd_index, DatetimeIndex):
2176+
final_pd_index = new_pd_index.tz_localize("UTC").tz_convert(self.tz)
2177+
else:
2178+
final_pd_index = new_pd_index
21742179
return final_pd_index, final_pd_index
21752180

21762181
def take_data(self):
@@ -2567,7 +2572,7 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str):
25672572
# reverse converts
25682573
if dtype.startswith("datetime64"):
25692574
# recreate with tz if indicated
2570-
converted = _set_tz(converted, tz, coerce=True)
2575+
converted = _set_tz(converted, tz)
25712576

25722577
elif dtype == "timedelta64":
25732578
converted = np.asarray(converted, dtype="m8[ns]")
@@ -2948,7 +2953,7 @@ def read_array(self, key: str, start: int | None = None, stop: int | None = None
29482953
if dtype and dtype.startswith("datetime64"):
29492954
# reconstruct a timezone if indicated
29502955
tz = getattr(attrs, "tz", None)
2951-
ret = _set_tz(ret, tz, coerce=True)
2956+
ret = _set_tz(ret, tz)
29522957

29532958
elif dtype == "timedelta64":
29542959
ret = np.asarray(ret, dtype="m8[ns]")
@@ -4298,7 +4303,8 @@ def read_column(
42984303
encoding=self.encoding,
42994304
errors=self.errors,
43004305
)
4301-
return Series(_set_tz(col_values[1], a.tz), name=column, copy=False)
4306+
cvs = col_values[1]
4307+
return Series(cvs, name=column, copy=False)
43024308

43034309
raise KeyError(f"column [{column}] not found in the table")
43044310

@@ -4637,7 +4643,7 @@ def read(
46374643
if values.ndim == 1 and isinstance(values, np.ndarray):
46384644
values = values.reshape((1, values.shape[0]))
46394645

4640-
if isinstance(values, np.ndarray):
4646+
if isinstance(values, (np.ndarray, DatetimeArray)):
46414647
df = DataFrame(values.T, columns=cols_, index=index_, copy=False)
46424648
elif isinstance(values, Index):
46434649
df = DataFrame(values, columns=cols_, index=index_)
@@ -4873,54 +4879,21 @@ def _get_tz(tz: tzinfo) -> str | tzinfo:
48734879
return zone
48744880

48754881

4876-
@overload
4877-
def _set_tz(
4878-
values: np.ndarray | Index, tz: str | tzinfo, coerce: bool = False
4879-
) -> DatetimeIndex:
4880-
...
4881-
4882-
4883-
@overload
4884-
def _set_tz(values: np.ndarray | Index, tz: None, coerce: bool = False) -> np.ndarray:
4885-
...
4886-
4887-
4888-
def _set_tz(
4889-
values: np.ndarray | Index, tz: str | tzinfo | None, coerce: bool = False
4890-
) -> np.ndarray | DatetimeIndex:
4882+
def _set_tz(values: npt.NDArray[np.int64], tz: str | tzinfo | None) -> DatetimeArray:
48914883
"""
4892-
coerce the values to a DatetimeIndex if tz is set
4893-
preserve the input shape if possible
4884+
Coerce the values to a DatetimeArray with appropriate tz.
48944885
48954886
Parameters
48964887
----------
4897-
values : ndarray or Index
4898-
tz : str or tzinfo
4899-
coerce : if we do not have a passed timezone, coerce to M8[ns] ndarray
4888+
values : ndarray[int64]
4889+
tz : str, tzinfo, or None
49004890
"""
4901-
if isinstance(values, DatetimeIndex):
4902-
# If values is tzaware, the tz gets dropped in the values.ravel()
4903-
# call below (which returns an ndarray). So we are only non-lossy
4904-
# if `tz` matches `values.tz`.
4905-
assert values.tz is None or values.tz == tz
4906-
if values.tz is not None:
4907-
return values
4908-
4909-
if tz is not None:
4910-
if isinstance(values, DatetimeIndex):
4911-
name = values.name
4912-
else:
4913-
name = None
4914-
values = values.ravel()
4915-
4916-
values = DatetimeIndex(values, name=name)
4917-
values = values.tz_localize("UTC").tz_convert(tz)
4918-
elif coerce:
4919-
values = np.asarray(values, dtype="M8[ns]")
4920-
4921-
# error: Incompatible return value type (got "Union[ndarray, Index]",
4922-
# expected "Union[ndarray, DatetimeIndex]")
4923-
return values # type: ignore[return-value]
4891+
assert values.dtype == "i8", values.dtype
4892+
# Argument "tz" to "tz_to_dtype" has incompatible type "str | tzinfo | None";
4893+
# expected "tzinfo"
4894+
dtype = tz_to_dtype(tz=tz, unit="ns") # type: ignore[arg-type]
4895+
dta = DatetimeArray._from_sequence(values, dtype=dtype)
4896+
return dta
49244897

49254898

49264899
def _convert_index(name: str, index: Index, encoding: str, errors: str) -> IndexCol:

0 commit comments

Comments
 (0)