Skip to content

Commit f197ca5

Browse files
authored
ENH: 2D compat for DTA tz_localize, to_period (#37950)
1 parent 8fd2d0c commit f197ca5

File tree

7 files changed

+58
-5
lines changed

7 files changed

+58
-5
lines changed

pandas/core/arrays/_mixins.py

+21
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from functools import wraps
34
from typing import Any, Optional, Sequence, Type, TypeVar, Union
45

56
import numpy as np
@@ -27,6 +28,26 @@
2728
)
2829

2930

31+
def ravel_compat(meth):
32+
"""
33+
Decorator to ravel a 2D array before passing it to a cython operation,
34+
then reshape the result to our own shape.
35+
"""
36+
37+
@wraps(meth)
38+
def method(self, *args, **kwargs):
39+
if self.ndim == 1:
40+
return meth(self, *args, **kwargs)
41+
42+
flags = self._ndarray.flags
43+
flat = self.ravel("K")
44+
result = meth(flat, *args, **kwargs)
45+
order = "F" if flags.f_contiguous else "C"
46+
return result.reshape(self.shape, order=order)
47+
48+
return method
49+
50+
3051
class NDArrayBackedExtensionArray(ExtensionArray):
3152
"""
3253
ExtensionArray that is backed by a single NumPy ndarray.

pandas/core/arrays/datetimelike.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
from pandas.core import nanops, ops
6565
from pandas.core.algorithms import checked_add_with_arr, isin, unique1d, value_counts
6666
from pandas.core.arraylike import OpsMixin
67-
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
67+
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray, ravel_compat
6868
import pandas.core.common as com
6969
from pandas.core.construction import array, extract_array
7070
from pandas.core.indexers import check_array_indexer, check_setitem_lengths
@@ -679,6 +679,9 @@ def value_counts(self, dropna: bool = False):
679679
-------
680680
Series
681681
"""
682+
if self.ndim != 1:
683+
raise NotImplementedError
684+
682685
from pandas import Index, Series
683686

684687
if dropna:
@@ -694,6 +697,7 @@ def value_counts(self, dropna: bool = False):
694697
)
695698
return Series(result._values, index=index, name=result.name)
696699

700+
@ravel_compat
697701
def map(self, mapper):
698702
# TODO(GH-23179): Add ExtensionArray.map
699703
# Need to figure out if we want ExtensionArray.map first.
@@ -820,6 +824,9 @@ def freq(self, value):
820824
value = to_offset(value)
821825
self._validate_frequency(self, value)
822826

827+
if self.ndim > 1:
828+
raise ValueError("Cannot set freq with ndim > 1")
829+
823830
self._freq = value
824831

825832
@property
@@ -918,7 +925,7 @@ def _is_monotonic_decreasing(self) -> bool:
918925

919926
@property
920927
def _is_unique(self) -> bool:
921-
return len(unique1d(self.asi8)) == len(self)
928+
return len(unique1d(self.asi8.ravel("K"))) == self.size
922929

923930
# ------------------------------------------------------------------
924931
# Arithmetic Methods

pandas/core/arrays/datetimes.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -612,14 +612,15 @@ def astype(self, dtype, copy=True):
612612
# -----------------------------------------------------------------
613613
# Rendering Methods
614614

615+
@dtl.ravel_compat
615616
def _format_native_types(self, na_rep="NaT", date_format=None, **kwargs):
616617
from pandas.io.formats.format import get_format_datetime64_from_values
617618

618619
fmt = get_format_datetime64_from_values(self, date_format)
619620

620621
return tslib.format_array_from_datetime(
621-
self.asi8.ravel(), tz=self.tz, format=fmt, na_rep=na_rep
622-
).reshape(self.shape)
622+
self.asi8, tz=self.tz, format=fmt, na_rep=na_rep
623+
)
623624

624625
# -----------------------------------------------------------------
625626
# Comparison Methods
@@ -819,6 +820,7 @@ def tz_convert(self, tz):
819820
dtype = tz_to_dtype(tz)
820821
return self._simple_new(self.asi8, dtype=dtype, freq=self.freq)
821822

823+
@dtl.ravel_compat
822824
def tz_localize(self, tz, ambiguous="raise", nonexistent="raise"):
823825
"""
824826
Localize tz-naive Datetime Array/Index to tz-aware
@@ -1051,6 +1053,7 @@ def normalize(self):
10511053
new_values = normalize_i8_timestamps(self.asi8, self.tz)
10521054
return type(self)(new_values)._with_freq("infer").tz_localize(self.tz)
10531055

1056+
@dtl.ravel_compat
10541057
def to_period(self, freq=None):
10551058
"""
10561059
Cast to PeriodArray/Index at a particular frequency.

pandas/core/arrays/period.py

+1
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ def _formatter(self, boxed: bool = False):
562562
return str
563563
return "'{}'".format
564564

565+
@dtl.ravel_compat
565566
def _format_native_types(self, na_rep="NaT", date_format=None, **kwargs):
566567
"""
567568
actually format my specific types

pandas/core/arrays/timedeltas.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -400,11 +400,12 @@ def _formatter(self, boxed=False):
400400

401401
return get_format_timedelta64(self, box=True)
402402

403+
@dtl.ravel_compat
403404
def _format_native_types(self, na_rep="NaT", date_format=None, **kwargs):
404405
from pandas.io.formats.format import get_format_timedelta64
405406

406407
formatter = get_format_timedelta64(self._data, na_rep)
407-
return np.array([formatter(x) for x in self._data.ravel()]).reshape(self.shape)
408+
return np.array([formatter(x) for x in self._data])
408409

409410
# ----------------------------------------------------------------
410411
# Arithmetic Methods

pandas/tests/arrays/test_datetimelike.py

+9
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,15 @@ def test_to_period(self, datetime_index, freqstr):
720720
# an EA-specific tm.assert_ function
721721
tm.assert_index_equal(pd.Index(result), pd.Index(expected))
722722

723+
def test_to_period_2d(self, arr1d):
724+
arr2d = arr1d.reshape(1, -1)
725+
726+
warn = None if arr1d.tz is None else UserWarning
727+
with tm.assert_produces_warning(warn):
728+
result = arr2d.to_period("D")
729+
expected = arr1d.to_period("D").reshape(1, -1)
730+
tm.assert_period_array_equal(result, expected)
731+
723732
@pytest.mark.parametrize("propname", pd.DatetimeIndex._bool_ops)
724733
def test_bool_properties(self, arr1d, propname):
725734
# in this case _bool_ops is just `is_leap_year`

pandas/tests/arrays/test_datetimes.py

+11
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,17 @@ def test_shift_requires_tzmatch(self):
449449
with pytest.raises(ValueError, match=msg):
450450
dta.shift(1, fill_value=fill_value)
451451

452+
def test_tz_localize_t2d(self):
453+
dti = pd.date_range("1994-05-12", periods=12, tz="US/Pacific")
454+
dta = dti._data.reshape(3, 4)
455+
result = dta.tz_localize(None)
456+
457+
expected = dta.ravel().tz_localize(None).reshape(dta.shape)
458+
tm.assert_datetime_array_equal(result, expected)
459+
460+
roundtrip = expected.tz_localize("US/Pacific")
461+
tm.assert_datetime_array_equal(roundtrip, dta)
462+
452463

453464
class TestSequenceToDT64NS:
454465
def test_tz_dtype_mismatch_raises(self):

0 commit comments

Comments
 (0)