diff --git a/pandas/_libs/tslibs/fields.pyx b/pandas/_libs/tslibs/fields.pyx index 16fa05c3801c6..57404b99c7628 100644 --- a/pandas/_libs/tslibs/fields.pyx +++ b/pandas/_libs/tslibs/fields.pyx @@ -174,6 +174,18 @@ def get_date_name_field(const int64_t[:] dtindex, str field, object locale=None) return out +cdef inline bint _is_on_month(int month, int compare_month, int modby) nogil: + """ + Analogous to DateOffset.is_on_offset checking for the month part of a date. + """ + if modby == 1: + return True + elif modby == 3: + return (month - compare_month) % 3 == 0 + else: + return month == compare_month + + @cython.wraparound(False) @cython.boundscheck(False) def get_start_end_field(const int64_t[:] dtindex, str field, @@ -191,6 +203,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field, int start_month = 1 ndarray[int8_t] out npy_datetimestruct dts + int compare_month, modby out = np.zeros(count, dtype='int8') @@ -215,102 +228,15 @@ def get_start_end_field(const int64_t[:] dtindex, str field, end_month = 12 start_month = 1 - if field == 'is_month_start': - if is_business: - for i in range(count): - if dtindex[i] == NPY_NAT: - out[i] = 0 - continue - - dt64_to_dtstruct(dtindex[i], &dts) - - if dts.day == get_firstbday(dts.year, dts.month): - out[i] = 1 - - else: - for i in range(count): - if dtindex[i] == NPY_NAT: - out[i] = 0 - continue - - dt64_to_dtstruct(dtindex[i], &dts) - - if dts.day == 1: - out[i] = 1 - - elif field == 'is_month_end': - if is_business: - for i in range(count): - if dtindex[i] == NPY_NAT: - out[i] = 0 - continue - - dt64_to_dtstruct(dtindex[i], &dts) - - if dts.day == get_lastbday(dts.year, dts.month): - out[i] = 1 - - else: - for i in range(count): - if dtindex[i] == NPY_NAT: - out[i] = 0 - continue - - dt64_to_dtstruct(dtindex[i], &dts) - - if dts.day == get_days_in_month(dts.year, dts.month): - out[i] = 1 - - elif field == 'is_quarter_start': - if is_business: - for i in range(count): - if dtindex[i] == NPY_NAT: - out[i] = 0 - continue - - dt64_to_dtstruct(dtindex[i], &dts) - - if ((dts.month - start_month) % 3 == 0) and ( - dts.day == get_firstbday(dts.year, dts.month)): - out[i] = 1 - - else: - for i in range(count): - if dtindex[i] == NPY_NAT: - out[i] = 0 - continue - - dt64_to_dtstruct(dtindex[i], &dts) - - if ((dts.month - start_month) % 3 == 0) and dts.day == 1: - out[i] = 1 - - elif field == 'is_quarter_end': - if is_business: - for i in range(count): - if dtindex[i] == NPY_NAT: - out[i] = 0 - continue - - dt64_to_dtstruct(dtindex[i], &dts) - - if ((dts.month - end_month) % 3 == 0) and ( - dts.day == get_lastbday(dts.year, dts.month)): - out[i] = 1 - - else: - for i in range(count): - if dtindex[i] == NPY_NAT: - out[i] = 0 - continue - - dt64_to_dtstruct(dtindex[i], &dts) - - if ((dts.month - end_month) % 3 == 0) and ( - dts.day == get_days_in_month(dts.year, dts.month)): - out[i] = 1 + compare_month = start_month if "start" in field else end_month + if "month" in field: + modby = 1 + elif "quarter" in field: + modby = 3 + else: + modby = 12 - elif field == 'is_year_start': + if field in ["is_month_start", "is_quarter_start", "is_year_start"]: if is_business: for i in range(count): if dtindex[i] == NPY_NAT: @@ -319,7 +245,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field, dt64_to_dtstruct(dtindex[i], &dts) - if (dts.month == start_month) and ( + if _is_on_month(dts.month, compare_month, modby) and ( dts.day == get_firstbday(dts.year, dts.month)): out[i] = 1 @@ -331,10 +257,10 @@ def get_start_end_field(const int64_t[:] dtindex, str field, dt64_to_dtstruct(dtindex[i], &dts) - if (dts.month == start_month) and dts.day == 1: + if _is_on_month(dts.month, compare_month, modby) and dts.day == 1: out[i] = 1 - elif field == 'is_year_end': + elif field in ["is_month_end", "is_quarter_end", "is_year_end"]: if is_business: for i in range(count): if dtindex[i] == NPY_NAT: @@ -343,7 +269,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field, dt64_to_dtstruct(dtindex[i], &dts) - if (dts.month == end_month) and ( + if _is_on_month(dts.month, compare_month, modby) and ( dts.day == get_lastbday(dts.year, dts.month)): out[i] = 1 @@ -355,7 +281,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field, dt64_to_dtstruct(dtindex[i], &dts) - if (dts.month == end_month) and ( + if _is_on_month(dts.month, compare_month, modby) and ( dts.day == get_days_in_month(dts.year, dts.month)): out[i] = 1