Skip to content

Commit b05e6b1

Browse files
authored
REF: de-duplicate code in liboffsets (#34778)
1 parent 87d8b5e commit b05e6b1

File tree

1 file changed

+79
-207
lines changed

1 file changed

+79
-207
lines changed

pandas/_libs/tslibs/offsets.pyx

+79-207
Original file line numberDiff line numberDiff line change
@@ -3723,136 +3723,14 @@ cdef shift_quarters(
37233723
out : ndarray[int64_t]
37243724
"""
37253725
cdef:
3726-
Py_ssize_t i
3727-
npy_datetimestruct dts
3728-
int count = len(dtindex)
3729-
int months_to_roll, months_since, n, compare_day
3726+
Py_ssize_t count = len(dtindex)
37303727
int64_t[:] out = np.empty(count, dtype="int64")
37313728

3732-
if day_opt == "start":
3733-
with nogil:
3734-
for i in range(count):
3735-
if dtindex[i] == NPY_NAT:
3736-
out[i] = NPY_NAT
3737-
continue
3738-
3739-
dt64_to_dtstruct(dtindex[i], &dts)
3740-
n = quarters
3741-
3742-
months_since = (dts.month - q1start_month) % modby
3743-
compare_day = get_day_of_month(&dts, day_opt)
3744-
3745-
# offset semantics - if on the anchor point and going backwards
3746-
# shift to next
3747-
if n <= 0 and (months_since != 0 or
3748-
(months_since == 0 and dts.day > compare_day)):
3749-
# make sure to roll forward, so negate
3750-
n += 1
3751-
elif n > 0 and (months_since == 0 and dts.day < compare_day):
3752-
# pretend to roll back if on same month but
3753-
# before compare_day
3754-
n -= 1
3755-
3756-
dts.year = year_add_months(dts, modby * n - months_since)
3757-
dts.month = month_add_months(dts, modby * n - months_since)
3758-
dts.day = get_day_of_month(&dts, day_opt)
3759-
3760-
out[i] = dtstruct_to_dt64(&dts)
3761-
3762-
elif day_opt == "end":
3763-
with nogil:
3764-
for i in range(count):
3765-
if dtindex[i] == NPY_NAT:
3766-
out[i] = NPY_NAT
3767-
continue
3768-
3769-
dt64_to_dtstruct(dtindex[i], &dts)
3770-
n = quarters
3771-
3772-
months_since = (dts.month - q1start_month) % modby
3773-
compare_day = get_day_of_month(&dts, day_opt)
3774-
3775-
if n <= 0 and (months_since != 0 or
3776-
(months_since == 0 and dts.day > compare_day)):
3777-
# make sure to roll forward, so negate
3778-
n += 1
3779-
elif n > 0 and (months_since == 0 and dts.day < compare_day):
3780-
# pretend to roll back if on same month but
3781-
# before compare_day
3782-
n -= 1
3783-
3784-
dts.year = year_add_months(dts, modby * n - months_since)
3785-
dts.month = month_add_months(dts, modby * n - months_since)
3786-
dts.day = get_day_of_month(&dts, day_opt)
3787-
3788-
out[i] = dtstruct_to_dt64(&dts)
3789-
3790-
elif day_opt == "business_start":
3791-
with nogil:
3792-
for i in range(count):
3793-
if dtindex[i] == NPY_NAT:
3794-
out[i] = NPY_NAT
3795-
continue
3796-
3797-
dt64_to_dtstruct(dtindex[i], &dts)
3798-
n = quarters
3799-
3800-
months_since = (dts.month - q1start_month) % modby
3801-
# compare_day is only relevant for comparison in the case
3802-
# where months_since == 0.
3803-
compare_day = get_day_of_month(&dts, day_opt)
3804-
3805-
if n <= 0 and (months_since != 0 or
3806-
(months_since == 0 and dts.day > compare_day)):
3807-
# make sure to roll forward, so negate
3808-
n += 1
3809-
elif n > 0 and (months_since == 0 and dts.day < compare_day):
3810-
# pretend to roll back if on same month but
3811-
# before compare_day
3812-
n -= 1
3813-
3814-
dts.year = year_add_months(dts, modby * n - months_since)
3815-
dts.month = month_add_months(dts, modby * n - months_since)
3816-
3817-
dts.day = get_day_of_month(&dts, day_opt)
3818-
3819-
out[i] = dtstruct_to_dt64(&dts)
3820-
3821-
elif day_opt == "business_end":
3822-
with nogil:
3823-
for i in range(count):
3824-
if dtindex[i] == NPY_NAT:
3825-
out[i] = NPY_NAT
3826-
continue
3827-
3828-
dt64_to_dtstruct(dtindex[i], &dts)
3829-
n = quarters
3830-
3831-
months_since = (dts.month - q1start_month) % modby
3832-
# compare_day is only relevant for comparison in the case
3833-
# where months_since == 0.
3834-
compare_day = get_day_of_month(&dts, day_opt)
3835-
3836-
if n <= 0 and (months_since != 0 or
3837-
(months_since == 0 and dts.day > compare_day)):
3838-
# make sure to roll forward, so negate
3839-
n += 1
3840-
elif n > 0 and (months_since == 0 and dts.day < compare_day):
3841-
# pretend to roll back if on same month but
3842-
# before compare_day
3843-
n -= 1
3844-
3845-
dts.year = year_add_months(dts, modby * n - months_since)
3846-
dts.month = month_add_months(dts, modby * n - months_since)
3847-
3848-
dts.day = get_day_of_month(&dts, day_opt)
3849-
3850-
out[i] = dtstruct_to_dt64(&dts)
3851-
3852-
else:
3729+
if day_opt not in ["start", "end", "business_start", "business_end"]:
38533730
raise ValueError("day must be None, 'start', 'end', "
38543731
"'business_start', or 'business_end'")
38553732

3733+
_shift_quarters(dtindex, out, count, quarters, q1start_month, day_opt, modby)
38563734
return np.asarray(out)
38573735

38583736

@@ -3872,7 +3750,6 @@ def shift_months(const int64_t[:] dtindex, int months, object day_opt=None):
38723750
Py_ssize_t i
38733751
npy_datetimestruct dts
38743752
int count = len(dtindex)
3875-
int months_to_roll
38763753
int64_t[:] out = np.empty(count, dtype="int64")
38773754

38783755
if day_opt is None:
@@ -3888,94 +3765,90 @@ def shift_months(const int64_t[:] dtindex, int months, object day_opt=None):
38883765

38893766
dts.day = min(dts.day, get_days_in_month(dts.year, dts.month))
38903767
out[i] = dtstruct_to_dt64(&dts)
3891-
elif day_opt == "start":
3892-
with nogil:
3893-
for i in range(count):
3894-
if dtindex[i] == NPY_NAT:
3895-
out[i] = NPY_NAT
3896-
continue
3897-
3898-
dt64_to_dtstruct(dtindex[i], &dts)
3899-
months_to_roll = months
3900-
compare_day = get_day_of_month(&dts, day_opt)
3768+
elif day_opt in ["start", "end", "business_start", "business_end"]:
3769+
_shift_months(dtindex, out, count, months, day_opt)
39013770

3902-
# offset semantics - if on the anchor point and going backwards
3903-
# shift to next
3904-
months_to_roll = roll_convention(dts.day, months_to_roll,
3905-
compare_day)
3906-
3907-
dts.year = year_add_months(dts, months_to_roll)
3908-
dts.month = month_add_months(dts, months_to_roll)
3909-
dts.day = get_day_of_month(&dts, day_opt)
3910-
3911-
out[i] = dtstruct_to_dt64(&dts)
3912-
elif day_opt == "end":
3913-
with nogil:
3914-
for i in range(count):
3915-
if dtindex[i] == NPY_NAT:
3916-
out[i] = NPY_NAT
3917-
continue
3771+
else:
3772+
raise ValueError("day must be None, 'start', 'end', "
3773+
"'business_start', or 'business_end'")
39183774

3919-
dt64_to_dtstruct(dtindex[i], &dts)
3920-
months_to_roll = months
3921-
compare_day = get_day_of_month(&dts, day_opt)
3775+
return np.asarray(out)
39223776

3923-
# similar semantics - when adding shift forward by one
3924-
# month if already at an end of month
3925-
months_to_roll = roll_convention(dts.day, months_to_roll,
3926-
compare_day)
39273777

3928-
dts.year = year_add_months(dts, months_to_roll)
3929-
dts.month = month_add_months(dts, months_to_roll)
3778+
@cython.wraparound(False)
3779+
@cython.boundscheck(False)
3780+
cdef inline void _shift_months(const int64_t[:] dtindex,
3781+
int64_t[:] out,
3782+
Py_ssize_t count,
3783+
int months,
3784+
str day_opt) nogil:
3785+
"""See shift_months.__doc__"""
3786+
cdef:
3787+
Py_ssize_t i
3788+
int months_to_roll, compare_day
3789+
npy_datetimestruct dts
39303790

3931-
dts.day = get_day_of_month(&dts, day_opt)
3932-
out[i] = dtstruct_to_dt64(&dts)
3791+
for i in range(count):
3792+
if dtindex[i] == NPY_NAT:
3793+
out[i] = NPY_NAT
3794+
continue
39333795

3934-
elif day_opt == "business_start":
3935-
with nogil:
3936-
for i in range(count):
3937-
if dtindex[i] == NPY_NAT:
3938-
out[i] = NPY_NAT
3939-
continue
3796+
dt64_to_dtstruct(dtindex[i], &dts)
3797+
months_to_roll = months
3798+
compare_day = get_day_of_month(&dts, day_opt)
39403799

3941-
dt64_to_dtstruct(dtindex[i], &dts)
3942-
months_to_roll = months
3943-
compare_day = get_day_of_month(&dts, day_opt)
3800+
months_to_roll = roll_convention(dts.day, months_to_roll,
3801+
compare_day)
39443802

3945-
months_to_roll = roll_convention(dts.day, months_to_roll,
3946-
compare_day)
3803+
dts.year = year_add_months(dts, months_to_roll)
3804+
dts.month = month_add_months(dts, months_to_roll)
3805+
dts.day = get_day_of_month(&dts, day_opt)
39473806

3948-
dts.year = year_add_months(dts, months_to_roll)
3949-
dts.month = month_add_months(dts, months_to_roll)
3807+
out[i] = dtstruct_to_dt64(&dts)
39503808

3951-
dts.day = get_day_of_month(&dts, day_opt)
3952-
out[i] = dtstruct_to_dt64(&dts)
39533809

3954-
elif day_opt == "business_end":
3955-
with nogil:
3956-
for i in range(count):
3957-
if dtindex[i] == NPY_NAT:
3958-
out[i] = NPY_NAT
3959-
continue
3810+
@cython.wraparound(False)
3811+
@cython.boundscheck(False)
3812+
cdef inline void _shift_quarters(const int64_t[:] dtindex,
3813+
int64_t[:] out,
3814+
Py_ssize_t count,
3815+
int quarters,
3816+
int q1start_month,
3817+
str day_opt,
3818+
int modby) nogil:
3819+
"""See shift_quarters.__doc__"""
3820+
cdef:
3821+
Py_ssize_t i
3822+
int months_since, compare_day, n
3823+
npy_datetimestruct dts
39603824

3961-
dt64_to_dtstruct(dtindex[i], &dts)
3962-
months_to_roll = months
3963-
compare_day = get_day_of_month(&dts, day_opt)
3825+
for i in range(count):
3826+
if dtindex[i] == NPY_NAT:
3827+
out[i] = NPY_NAT
3828+
continue
39643829

3965-
months_to_roll = roll_convention(dts.day, months_to_roll,
3966-
compare_day)
3830+
dt64_to_dtstruct(dtindex[i], &dts)
3831+
n = quarters
39673832

3968-
dts.year = year_add_months(dts, months_to_roll)
3969-
dts.month = month_add_months(dts, months_to_roll)
3833+
months_since = (dts.month - q1start_month) % modby
3834+
compare_day = get_day_of_month(&dts, day_opt)
39703835

3971-
dts.day = get_day_of_month(&dts, day_opt)
3972-
out[i] = dtstruct_to_dt64(&dts)
3836+
# offset semantics - if on the anchor point and going backwards
3837+
# shift to next
3838+
if n <= 0 and (months_since != 0 or
3839+
(months_since == 0 and dts.day > compare_day)):
3840+
# make sure to roll forward, so negate
3841+
n += 1
3842+
elif n > 0 and (months_since == 0 and dts.day < compare_day):
3843+
# pretend to roll back if on same month but
3844+
# before compare_day
3845+
n -= 1
39733846

3974-
else:
3975-
raise ValueError("day must be None, 'start', 'end', "
3976-
"'business_start', or 'business_end'")
3847+
dts.year = year_add_months(dts, modby * n - months_since)
3848+
dts.month = month_add_months(dts, modby * n - months_since)
3849+
dts.day = get_day_of_month(&dts, day_opt)
39773850

3978-
return np.asarray(out)
3851+
out[i] = dtstruct_to_dt64(&dts)
39793852

39803853

39813854
cdef ndarray[int64_t] shift_bdays(const int64_t[:] i8other, int periods):
@@ -4035,8 +3908,7 @@ cdef ndarray[int64_t] shift_bdays(const int64_t[:] i8other, int periods):
40353908
return result.base
40363909

40373910

4038-
def shift_month(stamp: datetime, months: int,
4039-
day_opt: object=None) -> datetime:
3911+
def shift_month(stamp: datetime, months: int, day_opt: object=None) -> datetime:
40403912
"""
40413913
Given a datetime (or Timestamp) `stamp`, an integer `months` and an
40423914
option `day_opt`, return a new datetimelike that many months later,
@@ -4078,14 +3950,14 @@ def shift_month(stamp: datetime, months: int,
40783950
if day_opt is None:
40793951
days_in_month = get_days_in_month(year, month)
40803952
day = min(stamp.day, days_in_month)
4081-
elif day_opt == 'start':
3953+
elif day_opt == "start":
40823954
day = 1
4083-
elif day_opt == 'end':
3955+
elif day_opt == "end":
40843956
day = get_days_in_month(year, month)
4085-
elif day_opt == 'business_start':
3957+
elif day_opt == "business_start":
40863958
# first business day of month
40873959
day = get_firstbday(year, month)
4088-
elif day_opt == 'business_end':
3960+
elif day_opt == "business_end":
40893961
# last business day of month
40903962
day = get_lastbday(year, month)
40913963
elif is_integer_object(day_opt):
@@ -4126,15 +3998,15 @@ cdef inline int get_day_of_month(npy_datetimestruct* dts, day_opt) nogil except?
41263998
cdef:
41273999
int days_in_month
41284000

4129-
if day_opt == 'start':
4001+
if day_opt == "start":
41304002
return 1
4131-
elif day_opt == 'end':
4003+
elif day_opt == "end":
41324004
days_in_month = get_days_in_month(dts.year, dts.month)
41334005
return days_in_month
4134-
elif day_opt == 'business_start':
4006+
elif day_opt == "business_start":
41354007
# first business day of month
41364008
return get_firstbday(dts.year, dts.month)
4137-
elif day_opt == 'business_end':
4009+
elif day_opt == "business_end":
41384010
# last business day of month
41394011
return get_lastbday(dts.year, dts.month)
41404012
elif day_opt is not None:

0 commit comments

Comments
 (0)