Skip to content

Commit aeed91c

Browse files
authored
TYP: tighter typing in _apply_array (pandas-dev#56083)
* TYP: tighter typing in _apply_array * comment * mypy fixup * mypy fixup * mypy fixup * update run_stubtest
1 parent 24fdde6 commit aeed91c

File tree

4 files changed

+26
-52
lines changed

4 files changed

+26
-52
lines changed

pandas/_libs/tslibs/offsets.pyi

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class ApplyTypeError(TypeError): ...
3333

3434
class BaseOffset:
3535
n: int
36+
normalize: bool
3637
def __init__(self, n: int = ..., normalize: bool = ...) -> None: ...
3738
def __eq__(self, other) -> bool: ...
3839
def __ne__(self, other) -> bool: ...
@@ -85,7 +86,7 @@ class BaseOffset:
8586
@property
8687
def freqstr(self) -> str: ...
8788
def _apply(self, other): ...
88-
def _apply_array(self, dtarr) -> None: ...
89+
def _apply_array(self, dtarr: np.ndarray) -> np.ndarray: ...
8990
def rollback(self, dt: datetime) -> datetime: ...
9091
def rollforward(self, dt: datetime) -> datetime: ...
9192
def is_on_offset(self, dt: datetime) -> bool: ...

pandas/_libs/tslibs/offsets.pyx

+10-43
Original file line numberDiff line numberDiff line change
@@ -110,33 +110,6 @@ cdef bint _is_normalized(datetime dt):
110110
return True
111111

112112

113-
def apply_wrapper_core(func, self, other) -> ndarray:
114-
result = func(self, other)
115-
result = np.asarray(result)
116-
117-
if self.normalize:
118-
# TODO: Avoid circular/runtime import
119-
from .vectorized import normalize_i8_timestamps
120-
reso = get_unit_from_dtype(other.dtype)
121-
result = normalize_i8_timestamps(result.view("i8"), None, reso=reso)
122-
123-
return result
124-
125-
126-
def apply_array_wraps(func):
127-
# Note: normally we would use `@functools.wraps(func)`, but this does
128-
# not play nicely with cython class methods
129-
def wrapper(self, other) -> np.ndarray:
130-
# other is a DatetimeArray
131-
result = apply_wrapper_core(func, self, other)
132-
return result
133-
134-
# do @functools.wraps(func) manually since it doesn't work on cdef funcs
135-
wrapper.__name__ = func.__name__
136-
wrapper.__doc__ = func.__doc__
137-
return wrapper
138-
139-
140113
def apply_wraps(func):
141114
# Note: normally we would use `@functools.wraps(func)`, but this does
142115
# not play nicely with cython class methods
@@ -644,8 +617,9 @@ cdef class BaseOffset:
644617
def _apply(self, other):
645618
raise NotImplementedError("implemented by subclasses")
646619

647-
@apply_array_wraps
648-
def _apply_array(self, dtarr):
620+
def _apply_array(self, dtarr: np.ndarray) -> np.ndarray:
621+
# NB: _apply_array does not handle respecting `self.normalize`, the
622+
# caller (DatetimeArray) handles that in post-processing.
649623
raise NotImplementedError(
650624
f"DateOffset subclass {type(self).__name__} "
651625
"does not have a vectorized implementation"
@@ -1399,8 +1373,7 @@ cdef class RelativeDeltaOffset(BaseOffset):
13991373
"applied vectorized"
14001374
)
14011375

1402-
@apply_array_wraps
1403-
def _apply_array(self, dtarr):
1376+
def _apply_array(self, dtarr: np.ndarray) -> np.ndarray:
14041377
reso = get_unit_from_dtype(dtarr.dtype)
14051378
dt64other = np.asarray(dtarr)
14061379

@@ -1814,8 +1787,7 @@ cdef class BusinessDay(BusinessMixin):
18141787
days = n + 2
18151788
return days
18161789

1817-
@apply_array_wraps
1818-
def _apply_array(self, dtarr):
1790+
def _apply_array(self, dtarr: np.ndarray) -> np.ndarray:
18191791
i8other = dtarr.view("i8")
18201792
reso = get_unit_from_dtype(dtarr.dtype)
18211793
res = self._shift_bdays(i8other, reso=reso)
@@ -2361,8 +2333,7 @@ cdef class YearOffset(SingleConstructorOffset):
23612333
months = years * 12 + (self.month - other.month)
23622334
return shift_month(other, months, self._day_opt)
23632335

2364-
@apply_array_wraps
2365-
def _apply_array(self, dtarr):
2336+
def _apply_array(self, dtarr: np.ndarray) -> np.ndarray:
23662337
reso = get_unit_from_dtype(dtarr.dtype)
23672338
shifted = shift_quarters(
23682339
dtarr.view("i8"), self.n, self.month, self._day_opt, modby=12, reso=reso
@@ -2613,8 +2584,7 @@ cdef class QuarterOffset(SingleConstructorOffset):
26132584
months = qtrs * 3 - months_since
26142585
return shift_month(other, months, self._day_opt)
26152586

2616-
@apply_array_wraps
2617-
def _apply_array(self, dtarr):
2587+
def _apply_array(self, dtarr: np.ndarray) -> np.ndarray:
26182588
reso = get_unit_from_dtype(dtarr.dtype)
26192589
shifted = shift_quarters(
26202590
dtarr.view("i8"),
@@ -2798,8 +2768,7 @@ cdef class MonthOffset(SingleConstructorOffset):
27982768
n = roll_convention(other.day, self.n, compare_day)
27992769
return shift_month(other, n, self._day_opt)
28002770

2801-
@apply_array_wraps
2802-
def _apply_array(self, dtarr):
2771+
def _apply_array(self, dtarr: np.ndarray) -> np.ndarray:
28032772
reso = get_unit_from_dtype(dtarr.dtype)
28042773
shifted = shift_months(dtarr.view("i8"), self.n, self._day_opt, reso=reso)
28052774
return shifted
@@ -3029,10 +2998,9 @@ cdef class SemiMonthOffset(SingleConstructorOffset):
30292998

30302999
return shift_month(other, months, to_day)
30313000

3032-
@apply_array_wraps
30333001
@cython.wraparound(False)
30343002
@cython.boundscheck(False)
3035-
def _apply_array(self, dtarr):
3003+
def _apply_array(self, dtarr: np.ndarray) -> np.ndarray:
30363004
cdef:
30373005
ndarray i8other = dtarr.view("i8")
30383006
Py_ssize_t i, count = dtarr.size
@@ -3254,8 +3222,7 @@ cdef class Week(SingleConstructorOffset):
32543222

32553223
return other + timedelta(weeks=k)
32563224

3257-
@apply_array_wraps
3258-
def _apply_array(self, dtarr):
3225+
def _apply_array(self, dtarr: np.ndarray) -> np.ndarray:
32593226
if self.weekday is None:
32603227
td = timedelta(days=7 * self.n)
32613228
unit = np.datetime_data(dtarr.dtype)[0]

pandas/core/arrays/datetimes.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ def _assert_tzawareness_compat(self, other) -> None:
790790
# -----------------------------------------------------------------
791791
# Arithmetic Methods
792792

793-
def _add_offset(self, offset) -> Self:
793+
def _add_offset(self, offset: BaseOffset) -> Self:
794794
assert not isinstance(offset, Tick)
795795

796796
if self.tz is not None:
@@ -799,24 +799,31 @@ def _add_offset(self, offset) -> Self:
799799
values = self
800800

801801
try:
802-
result = offset._apply_array(values)
803-
if result.dtype.kind == "i":
804-
result = result.view(values.dtype)
802+
res_values = offset._apply_array(values._ndarray)
803+
if res_values.dtype.kind == "i":
804+
# error: Argument 1 to "view" of "ndarray" has incompatible type
805+
# "dtype[datetime64] | DatetimeTZDtype"; expected
806+
# "dtype[Any] | type[Any] | _SupportsDType[dtype[Any]]"
807+
res_values = res_values.view(values.dtype) # type: ignore[arg-type]
805808
except NotImplementedError:
806809
warnings.warn(
807810
"Non-vectorized DateOffset being applied to Series or DatetimeIndex.",
808811
PerformanceWarning,
809812
stacklevel=find_stack_level(),
810813
)
811-
result = self.astype("O") + offset
814+
res_values = self.astype("O") + offset
812815
# TODO(GH#55564): as_unit will be unnecessary
813-
result = type(self)._from_sequence(result).as_unit(self.unit)
816+
result = type(self)._from_sequence(res_values).as_unit(self.unit)
814817
if not len(self):
815818
# GH#30336 _from_sequence won't be able to infer self.tz
816819
return result.tz_localize(self.tz)
817820

818821
else:
819-
result = type(self)._simple_new(result, dtype=result.dtype)
822+
result = type(self)._simple_new(res_values, dtype=res_values.dtype)
823+
if offset.normalize:
824+
result = result.normalize()
825+
result._freq = None
826+
820827
if self.tz is not None:
821828
result = result.tz_localize(self.tz)
822829

scripts/run_stubtest.py

-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@
6969
"pandas._libs.sparse.SparseIndex.to_block_index",
7070
"pandas._libs.sparse.SparseIndex.to_int_index",
7171
# TODO (decorator changes argument names)
72-
"pandas._libs.tslibs.offsets.BaseOffset._apply_array",
7372
"pandas._libs.tslibs.offsets.BusinessHour.rollback",
7473
"pandas._libs.tslibs.offsets.BusinessHour.rollforward ",
7574
# type alias

0 commit comments

Comments
 (0)