Skip to content

POC: rollback_array #52205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 93 additions & 15 deletions pandas/_libs/tslibs/offsets.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2235,6 +2235,20 @@ cdef class YearOffset(SingleConstructorOffset):
)
return shifted

@apply_array_wraps
def rollback_array(self, dtarr):
    """
    Array version of a roll-back for a yearly-anchored offset.

    Parameters
    ----------
    dtarr : array-like
        Presumably a datetime64 array (TODO confirm against callers); it is
        read through its ``"i8"`` view and its resolution is derived from
        ``dtarr.dtype``.

    Returns
    -------
    The value produced by ``shift_quarters(..., roll=True)``; any further
    processing is left to the ``apply_array_wraps`` decorator.
    """
    # Resolution of the incoming values, so the int64 data is decoded
    # with the matching unit.
    unit = get_unit_from_dtype(dtarr.dtype)
    i8values = dtarr.view("i8")
    # modby=12: the anchor month repeats every 12 months, i.e. once a year.
    return shift_quarters(
        i8values,
        self.n,
        self.month,
        self._day_opt,
        modby=12,
        reso=unit,
        roll=True,
    )


cdef class BYearEnd(YearOffset):
"""
Expand Down Expand Up @@ -2406,6 +2420,20 @@ cdef class QuarterOffset(SingleConstructorOffset):
)
return shifted

@apply_array_wraps
def rollback_array(self, dtarr):
    """
    Array version of a roll-back for a quarterly-anchored offset.

    Parameters
    ----------
    dtarr : array-like
        Presumably a datetime64 array (TODO confirm against callers); it is
        read through its ``"i8"`` view and its resolution is derived from
        ``dtarr.dtype``.

    Returns
    -------
    The value produced by ``shift_quarters(..., roll=True)``; any further
    processing is left to the ``apply_array_wraps`` decorator.
    """
    # Resolution of the incoming values, so the int64 data is decoded
    # with the matching unit.
    unit = get_unit_from_dtype(dtarr.dtype)
    i8values = dtarr.view("i8")
    # modby=3: the anchor month repeats every 3 months, i.e. once a quarter.
    return shift_quarters(
        i8values,
        self.n,
        self.startingMonth,
        self._day_opt,
        modby=3,
        reso=unit,
        roll=True,
    )


cdef class BQuarterEnd(QuarterOffset):
"""
Expand Down Expand Up @@ -2532,6 +2560,18 @@ cdef class MonthOffset(SingleConstructorOffset):
shifted = shift_months(dtarr.view("i8"), self.n, self._day_opt, reso=reso)
return shifted

@apply_array_wraps
def rollback_array(self, dtarr):
    """
    Array version of a roll-back for a monthly-anchored offset.

    Parameters
    ----------
    dtarr : array-like
        Presumably a datetime64 array (TODO confirm against callers); it is
        read through its ``"i8"`` view and its resolution is derived from
        ``dtarr.dtype``.

    Returns
    -------
    The value produced by ``shift_months(..., roll=True)``; any further
    processing is left to the ``apply_array_wraps`` decorator.
    """
    # Resolution of the incoming values, so the int64 data is decoded
    # with the matching unit.
    unit = get_unit_from_dtype(dtarr.dtype)
    i8values = dtarr.view("i8")
    return shift_months(
        i8values,
        self.n,
        self._day_opt,
        reso=unit,
        roll=True,
    )

cpdef __setstate__(self, state):
state.pop("_use_relativedelta", False)
state.pop("offset", None)
Expand Down Expand Up @@ -4310,6 +4350,7 @@ cdef ndarray shift_quarters(
str day_opt,
int modby=3,
NPY_DATETIMEUNIT reso=NPY_DATETIMEUNIT.NPY_FR_ns,
bint roll=False,
):
"""
Given an int64 array representing nanosecond timestamps, shift all elements
Expand Down Expand Up @@ -4353,13 +4394,26 @@ cdef ndarray shift_quarters(
n = quarters

months_since = (dts.month - q1start_month) % modby
n = _roll_qtrday(&dts, n, months_since, day_opt)
if roll:
if months_since == 0 and dts.day == get_day_of_month(&dts, day_opt):
# already on_offset
res_val = val
else:
n = _roll_qtrday(&dts, -1, months_since, day_opt)

dts.year = year_add_months(dts, modby * n - months_since)
dts.month = month_add_months(dts, modby * n - months_since)
dts.day = get_day_of_month(&dts, day_opt)

res_val = npy_datetimestruct_to_datetime(reso, &dts)
else:
n = _roll_qtrday(&dts, n, months_since, day_opt)

dts.year = year_add_months(dts, modby * n - months_since)
dts.month = month_add_months(dts, modby * n - months_since)
dts.day = get_day_of_month(&dts, day_opt)
dts.year = year_add_months(dts, modby * n - months_since)
dts.month = month_add_months(dts, modby * n - months_since)
dts.day = get_day_of_month(&dts, day_opt)

res_val = npy_datetimestruct_to_datetime(reso, &dts)
res_val = npy_datetimestruct_to_datetime(reso, &dts)

# Analogous to: out[i] = res_val
(<int64_t*>cnp.PyArray_MultiIter_DATA(mi, 0))[0] = res_val
Expand All @@ -4376,6 +4430,7 @@ def shift_months(
int months,
str day_opt=None,
NPY_DATETIMEUNIT reso=NPY_DATETIMEUNIT.NPY_FR_ns,
bint roll=False,
):
"""
Given an int64-based datetime index, shift all elements
Expand Down Expand Up @@ -4413,11 +4468,23 @@ def shift_months(
res_val = NPY_NAT
else:
pandas_datetime_to_datetimestruct(val, reso, &dts)
dts.year = year_add_months(dts, months)
dts.month = month_add_months(dts, months)
if roll:
if dts.day == get_days_in_month(dts.year, dts.month):
# i.e. we are on_offset
res_val = val
else:
# Roll back to the previous month
dts.year = year_add_months(dts, -1)
dts.month = month_add_months(dts, -1)
dts.day = get_days_in_month(dts.year, dts.month)
res_val = npy_datetimestruct_to_datetime(reso, &dts)

dts.day = min(dts.day, get_days_in_month(dts.year, dts.month))
res_val = npy_datetimestruct_to_datetime(reso, &dts)
else:
dts.year = year_add_months(dts, months)
dts.month = month_add_months(dts, months)

dts.day = min(dts.day, get_days_in_month(dts.year, dts.month))
res_val = npy_datetimestruct_to_datetime(reso, &dts)

# Analogous to: out[i] = res_val
(<int64_t*>cnp.PyArray_MultiIter_DATA(mi, 0))[0] = res_val
Expand All @@ -4435,15 +4502,26 @@ def shift_months(
res_val = NPY_NAT
else:
pandas_datetime_to_datetimestruct(val, reso, &dts)
months_to_roll = months
if roll:
if dts.day == get_day_of_month(&dts, day_opt):
# i.e. we are already on_offset
res_val = val
else:
months_to_roll = _roll_qtrday(&dts, -1, 0, day_opt)
dts.year = year_add_months(dts, months_to_roll)
dts.month = month_add_months(dts, months_to_roll)
dts.day = get_day_of_month(&dts, day_opt)

res_val = npy_datetimestruct_to_datetime(reso, &dts)

months_to_roll = _roll_qtrday(&dts, months_to_roll, 0, day_opt)
else:
months_to_roll = _roll_qtrday(&dts, months, 0, day_opt)

dts.year = year_add_months(dts, months_to_roll)
dts.month = month_add_months(dts, months_to_roll)
dts.day = get_day_of_month(&dts, day_opt)
dts.year = year_add_months(dts, months_to_roll)
dts.month = month_add_months(dts, months_to_roll)
dts.day = get_day_of_month(&dts, day_opt)

res_val = npy_datetimestruct_to_datetime(reso, &dts)
res_val = npy_datetimestruct_to_datetime(reso, &dts)

# Analogous to: out[i] = res_val
(<int64_t*>cnp.PyArray_MultiIter_DATA(mi, 0))[0] = res_val
Expand Down