Skip to content

Commit c10de3a

Browse files
authored
Backport PR pandas-dev#52677 on branch 2.0.x (BUG: tz_localize with ArrowDtype) (pandas-dev#52762)
Backport PR pandas-dev#52677: BUG: tz_localize with ArrowDtype
1 parent ef66c8e commit c10de3a

File tree

3 files changed

+59
-11
lines changed

3 files changed

+59
-11
lines changed

doc/source/whatsnew/v2.0.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Bug fixes
3333
- Bug in :meth:`ArrowDtype.__from_arrow__` not respecting if dtype is explicitly given (:issue:`52533`)
3434
- Bug in :meth:`DataFrame.max` and related casting different :class:`Timestamp` resolutions always to nanoseconds (:issue:`52524`)
3535
- Bug in :meth:`Series.describe` not returning :class:`ArrowDtype` with ``pyarrow.float64`` type with numeric data (:issue:`52427`)
36+
- Bug in :meth:`Series.dt.tz_localize` incorrectly localizing timestamps with :class:`ArrowDtype` (:issue:`52677`)
3637
- Fixed bug in :func:`merge` when merging with ``ArrowDtype`` one one and a NumPy dtype on the other side (:issue:`52406`)
3738
- Fixed segfault in :meth:`Series.to_numpy` with ``null[pyarrow]`` dtype (:issue:`52443`)
3839

pandas/core/arrays/arrow/array.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -2109,12 +2109,19 @@ def _dt_tz_localize(
21092109
):
21102110
if ambiguous != "raise":
21112111
raise NotImplementedError(f"{ambiguous=} is not supported")
2112-
if nonexistent != "raise":
2112+
nonexistent_pa = {
2113+
"raise": "raise",
2114+
"shift_backward": "earliest",
2115+
"shift_forward": "latest",
2116+
}.get(
2117+
nonexistent, None # type: ignore[arg-type]
2118+
)
2119+
if nonexistent_pa is None:
21132120
raise NotImplementedError(f"{nonexistent=} is not supported")
21142121
if tz is None:
2115-
new_type = pa.timestamp(self.dtype.pyarrow_dtype.unit)
2116-
return type(self)(self._data.cast(new_type))
2117-
pa_tz = str(tz)
2118-
return type(self)(
2119-
self._data.cast(pa.timestamp(self.dtype.pyarrow_dtype.unit, pa_tz))
2120-
)
2122+
result = self._data.cast(pa.timestamp(self.dtype.pyarrow_dtype.unit))
2123+
else:
2124+
result = pc.assume_timezone(
2125+
self._data, str(tz), ambiguous=ambiguous, nonexistent=nonexistent_pa
2126+
)
2127+
return type(self)(result)

pandas/tests/extension/test_arrow.py

+44-4
Original file line numberDiff line numberDiff line change
@@ -2397,16 +2397,56 @@ def test_dt_tz_localize_none():
23972397

23982398

23992399
@pytest.mark.parametrize("unit", ["us", "ns"])
2400-
def test_dt_tz_localize(unit):
2400+
def test_dt_tz_localize(unit, request):
2401+
if is_platform_windows() and is_ci_environment():
2402+
request.node.add_marker(
2403+
pytest.mark.xfail(
2404+
raises=pa.ArrowInvalid,
2405+
reason=(
2406+
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
2407+
"on CI to path to the tzdata for pyarrow."
2408+
),
2409+
)
2410+
)
24012411
ser = pd.Series(
24022412
[datetime(year=2023, month=1, day=2, hour=3), None],
24032413
dtype=ArrowDtype(pa.timestamp(unit)),
24042414
)
24052415
result = ser.dt.tz_localize("US/Pacific")
2406-
expected = pd.Series(
2407-
[datetime(year=2023, month=1, day=2, hour=3), None],
2408-
dtype=ArrowDtype(pa.timestamp(unit, "US/Pacific")),
2416+
exp_data = pa.array(
2417+
[datetime(year=2023, month=1, day=2, hour=3), None], type=pa.timestamp(unit)
2418+
)
2419+
exp_data = pa.compute.assume_timezone(exp_data, "US/Pacific")
2420+
expected = pd.Series(ArrowExtensionArray(exp_data))
2421+
tm.assert_series_equal(result, expected)
2422+
2423+
2424+
@pytest.mark.parametrize(
2425+
"nonexistent, exp_date",
2426+
[
2427+
["shift_forward", datetime(year=2023, month=3, day=12, hour=3)],
2428+
["shift_backward", pd.Timestamp("2023-03-12 01:59:59.999999999")],
2429+
],
2430+
)
2431+
def test_dt_tz_localize_nonexistent(nonexistent, exp_date, request):
2432+
if is_platform_windows() and is_ci_environment():
2433+
request.node.add_marker(
2434+
pytest.mark.xfail(
2435+
raises=pa.ArrowInvalid,
2436+
reason=(
2437+
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
2438+
"on CI to path to the tzdata for pyarrow."
2439+
),
2440+
)
2441+
)
2442+
ser = pd.Series(
2443+
[datetime(year=2023, month=3, day=12, hour=2, minute=30), None],
2444+
dtype=ArrowDtype(pa.timestamp("ns")),
24092445
)
2446+
result = ser.dt.tz_localize("US/Pacific", nonexistent=nonexistent)
2447+
exp_data = pa.array([exp_date, None], type=pa.timestamp("ns"))
2448+
exp_data = pa.compute.assume_timezone(exp_data, "US/Pacific")
2449+
expected = pd.Series(ArrowExtensionArray(exp_data))
24102450
tm.assert_series_equal(result, expected)
24112451

24122452

0 commit comments

Comments
 (0)