Skip to content

Commit ff8e88a

Browse files
authored
BUG: tz_localize with ArrowDtype (#52677)
* BUG: tz_localize with ArrowDtype * Add issue number * Add issue number * typing * Re-place * xfail for windows tzdata
1 parent 11d75d8 commit ff8e88a

File tree

3 files changed

+59
-11
lines changed

3 files changed

+59
-11
lines changed

doc/source/whatsnew/v2.0.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Bug fixes
3333
- Bug in :meth:`ArrowDtype.__from_arrow__` not respecting if dtype is explicitly given (:issue:`52533`)
3434
- Bug in :meth:`DataFrame.max` and related casting different :class:`Timestamp` resolutions always to nanoseconds (:issue:`52524`)
3535
- Bug in :meth:`Series.describe` not returning :class:`ArrowDtype` with ``pyarrow.float64`` type with numeric data (:issue:`52427`)
36+
- Bug in :meth:`Series.dt.tz_localize` incorrectly localizing timestamps with :class:`ArrowDtype` (:issue:`52677`)
3637
- Fixed bug in :func:`merge` when merging with ``ArrowDtype`` on one side and a NumPy dtype on the other side (:issue:`52406`)
3738
- Fixed segfault in :meth:`Series.to_numpy` with ``null[pyarrow]`` dtype (:issue:`52443`)
3839

pandas/core/arrays/arrow/array.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -2208,12 +2208,19 @@ def _dt_tz_localize(
22082208
):
22092209
if ambiguous != "raise":
22102210
raise NotImplementedError(f"{ambiguous=} is not supported")
2211-
if nonexistent != "raise":
2211+
nonexistent_pa = {
2212+
"raise": "raise",
2213+
"shift_backward": "earliest",
2214+
"shift_forward": "latest",
2215+
}.get(
2216+
nonexistent, None # type: ignore[arg-type]
2217+
)
2218+
if nonexistent_pa is None:
22122219
raise NotImplementedError(f"{nonexistent=} is not supported")
22132220
if tz is None:
2214-
new_type = pa.timestamp(self.dtype.pyarrow_dtype.unit)
2215-
return type(self)(self._pa_array.cast(new_type))
2216-
pa_tz = str(tz)
2217-
return type(self)(
2218-
self._pa_array.cast(pa.timestamp(self.dtype.pyarrow_dtype.unit, pa_tz))
2219-
)
2221+
result = self._pa_array.cast(pa.timestamp(self.dtype.pyarrow_dtype.unit))
2222+
else:
2223+
result = pc.assume_timezone(
2224+
self._pa_array, str(tz), ambiguous=ambiguous, nonexistent=nonexistent_pa
2225+
)
2226+
return type(self)(result)

pandas/tests/extension/test_arrow.py

+44-4
Original file line numberDiff line numberDiff line change
@@ -2392,16 +2392,56 @@ def test_dt_tz_localize_none():
23922392

23932393

23942394
@pytest.mark.parametrize("unit", ["us", "ns"])
2395-
def test_dt_tz_localize(unit):
2395+
def test_dt_tz_localize(unit, request):
2396+
if is_platform_windows() and is_ci_environment():
2397+
request.node.add_marker(
2398+
pytest.mark.xfail(
2399+
raises=pa.ArrowInvalid,
2400+
reason=(
2401+
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
2402+
"on CI to path to the tzdata for pyarrow."
2403+
),
2404+
)
2405+
)
23962406
ser = pd.Series(
23972407
[datetime(year=2023, month=1, day=2, hour=3), None],
23982408
dtype=ArrowDtype(pa.timestamp(unit)),
23992409
)
24002410
result = ser.dt.tz_localize("US/Pacific")
2401-
expected = pd.Series(
2402-
[datetime(year=2023, month=1, day=2, hour=3), None],
2403-
dtype=ArrowDtype(pa.timestamp(unit, "US/Pacific")),
2411+
exp_data = pa.array(
2412+
[datetime(year=2023, month=1, day=2, hour=3), None], type=pa.timestamp(unit)
2413+
)
2414+
exp_data = pa.compute.assume_timezone(exp_data, "US/Pacific")
2415+
expected = pd.Series(ArrowExtensionArray(exp_data))
2416+
tm.assert_series_equal(result, expected)
2417+
2418+
2419+
@pytest.mark.parametrize(
2420+
"nonexistent, exp_date",
2421+
[
2422+
["shift_forward", datetime(year=2023, month=3, day=12, hour=3)],
2423+
["shift_backward", pd.Timestamp("2023-03-12 01:59:59.999999999")],
2424+
],
2425+
)
2426+
def test_dt_tz_localize_nonexistent(nonexistent, exp_date, request):
2427+
if is_platform_windows() and is_ci_environment():
2428+
request.node.add_marker(
2429+
pytest.mark.xfail(
2430+
raises=pa.ArrowInvalid,
2431+
reason=(
2432+
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
2433+
"on CI to path to the tzdata for pyarrow."
2434+
),
2435+
)
2436+
)
2437+
ser = pd.Series(
2438+
[datetime(year=2023, month=3, day=12, hour=2, minute=30), None],
2439+
dtype=ArrowDtype(pa.timestamp("ns")),
24042440
)
2441+
result = ser.dt.tz_localize("US/Pacific", nonexistent=nonexistent)
2442+
exp_data = pa.array([exp_date, None], type=pa.timestamp("ns"))
2443+
exp_data = pa.compute.assume_timezone(exp_data, "US/Pacific")
2444+
expected = pd.Series(ArrowExtensionArray(exp_data))
24052445
tm.assert_series_equal(result, expected)
24062446

24072447

0 commit comments

Comments
 (0)