Skip to content

REF: de-duplicate tslibs.fields #38950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 4, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 26 additions & 100 deletions pandas/_libs/tslibs/fields.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,18 @@ def get_date_name_field(const int64_t[:] dtindex, str field, object locale=None)
return out


cdef inline bint _is_on_month(int month, int compare_month, int modby) nogil:
"""
Analogous to DateOffset.is_on_offset checking for the month part of a date.
"""
if modby == 1:
return True
elif modby == 3:
return (month - compare_month) % 3 == 0
else:
return month == compare_month


@cython.wraparound(False)
@cython.boundscheck(False)
def get_start_end_field(const int64_t[:] dtindex, str field,
Expand All @@ -191,6 +203,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
int start_month = 1
ndarray[int8_t] out
npy_datetimestruct dts
int compare_month, modby

out = np.zeros(count, dtype='int8')

Expand All @@ -215,102 +228,15 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
end_month = 12
start_month = 1

if field == 'is_month_start':
if is_business:
for i in range(count):
if dtindex[i] == NPY_NAT:
out[i] = 0
continue

dt64_to_dtstruct(dtindex[i], &dts)

if dts.day == get_firstbday(dts.year, dts.month):
out[i] = 1

else:
for i in range(count):
if dtindex[i] == NPY_NAT:
out[i] = 0
continue

dt64_to_dtstruct(dtindex[i], &dts)

if dts.day == 1:
out[i] = 1

elif field == 'is_month_end':
if is_business:
for i in range(count):
if dtindex[i] == NPY_NAT:
out[i] = 0
continue

dt64_to_dtstruct(dtindex[i], &dts)

if dts.day == get_lastbday(dts.year, dts.month):
out[i] = 1

else:
for i in range(count):
if dtindex[i] == NPY_NAT:
out[i] = 0
continue

dt64_to_dtstruct(dtindex[i], &dts)

if dts.day == get_days_in_month(dts.year, dts.month):
out[i] = 1

elif field == 'is_quarter_start':
if is_business:
for i in range(count):
if dtindex[i] == NPY_NAT:
out[i] = 0
continue

dt64_to_dtstruct(dtindex[i], &dts)

if ((dts.month - start_month) % 3 == 0) and (
dts.day == get_firstbday(dts.year, dts.month)):
out[i] = 1

else:
for i in range(count):
if dtindex[i] == NPY_NAT:
out[i] = 0
continue

dt64_to_dtstruct(dtindex[i], &dts)

if ((dts.month - start_month) % 3 == 0) and dts.day == 1:
out[i] = 1

elif field == 'is_quarter_end':
if is_business:
for i in range(count):
if dtindex[i] == NPY_NAT:
out[i] = 0
continue

dt64_to_dtstruct(dtindex[i], &dts)

if ((dts.month - end_month) % 3 == 0) and (
dts.day == get_lastbday(dts.year, dts.month)):
out[i] = 1

else:
for i in range(count):
if dtindex[i] == NPY_NAT:
out[i] = 0
continue

dt64_to_dtstruct(dtindex[i], &dts)

if ((dts.month - end_month) % 3 == 0) and (
dts.day == get_days_in_month(dts.year, dts.month)):
out[i] = 1
compare_month = start_month if "start" in field else end_month
if "month" in field:
modby = 1
elif "quarter" in field:
modby = 3
else:
modby = 12

elif field == 'is_year_start':
if field in ["is_month_start", "is_quarter_start", "is_year_start"]:
if is_business:
for i in range(count):
if dtindex[i] == NPY_NAT:
Expand All @@ -319,7 +245,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,

dt64_to_dtstruct(dtindex[i], &dts)

if (dts.month == start_month) and (
if _is_on_month(dts.month, compare_month, modby) and (
dts.day == get_firstbday(dts.year, dts.month)):
out[i] = 1

Expand All @@ -331,10 +257,10 @@ def get_start_end_field(const int64_t[:] dtindex, str field,

dt64_to_dtstruct(dtindex[i], &dts)

if (dts.month == start_month) and dts.day == 1:
if _is_on_month(dts.month, compare_month, modby) and dts.day == 1:
out[i] = 1

elif field == 'is_year_end':
elif field in ["is_month_end", "is_quarter_end", "is_year_end"]:
if is_business:
for i in range(count):
if dtindex[i] == NPY_NAT:
Expand All @@ -343,7 +269,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,

dt64_to_dtstruct(dtindex[i], &dts)

if (dts.month == end_month) and (
if _is_on_month(dts.month, compare_month, modby) and (
dts.day == get_lastbday(dts.year, dts.month)):
out[i] = 1

Expand All @@ -355,7 +281,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,

dt64_to_dtstruct(dtindex[i], &dts)

if (dts.month == end_month) and (
if _is_on_month(dts.month, compare_month, modby) and (
dts.day == get_days_in_month(dts.year, dts.month)):
out[i] = 1

Expand Down