Skip to content

Commit 0cbcbf6

Browse files
jbrockmendelluckyvs1
authored andcommitted
REF: de-duplicate tslibs.fields (pandas-dev#38950)
1 parent 0ce0f6b commit 0cbcbf6

File tree

1 file changed

+26
-100
lines changed

1 file changed

+26
-100
lines changed

pandas/_libs/tslibs/fields.pyx

+26-100
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,18 @@ def get_date_name_field(const int64_t[:] dtindex, str field, object locale=None)
174174
return out
175175

176176

177+
cdef inline bint _is_on_month(int month, int compare_month, int modby) nogil:
178+
"""
179+
Analogous to DateOffset.is_on_offset checking for the month part of a date.
180+
"""
181+
if modby == 1:
182+
return True
183+
elif modby == 3:
184+
return (month - compare_month) % 3 == 0
185+
else:
186+
return month == compare_month
187+
188+
177189
@cython.wraparound(False)
178190
@cython.boundscheck(False)
179191
def get_start_end_field(const int64_t[:] dtindex, str field,
@@ -191,6 +203,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
191203
int start_month = 1
192204
ndarray[int8_t] out
193205
npy_datetimestruct dts
206+
int compare_month, modby
194207

195208
out = np.zeros(count, dtype='int8')
196209

@@ -215,102 +228,15 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
215228
end_month = 12
216229
start_month = 1
217230

218-
if field == 'is_month_start':
219-
if is_business:
220-
for i in range(count):
221-
if dtindex[i] == NPY_NAT:
222-
out[i] = 0
223-
continue
224-
225-
dt64_to_dtstruct(dtindex[i], &dts)
226-
227-
if dts.day == get_firstbday(dts.year, dts.month):
228-
out[i] = 1
229-
230-
else:
231-
for i in range(count):
232-
if dtindex[i] == NPY_NAT:
233-
out[i] = 0
234-
continue
235-
236-
dt64_to_dtstruct(dtindex[i], &dts)
237-
238-
if dts.day == 1:
239-
out[i] = 1
240-
241-
elif field == 'is_month_end':
242-
if is_business:
243-
for i in range(count):
244-
if dtindex[i] == NPY_NAT:
245-
out[i] = 0
246-
continue
247-
248-
dt64_to_dtstruct(dtindex[i], &dts)
249-
250-
if dts.day == get_lastbday(dts.year, dts.month):
251-
out[i] = 1
252-
253-
else:
254-
for i in range(count):
255-
if dtindex[i] == NPY_NAT:
256-
out[i] = 0
257-
continue
258-
259-
dt64_to_dtstruct(dtindex[i], &dts)
260-
261-
if dts.day == get_days_in_month(dts.year, dts.month):
262-
out[i] = 1
263-
264-
elif field == 'is_quarter_start':
265-
if is_business:
266-
for i in range(count):
267-
if dtindex[i] == NPY_NAT:
268-
out[i] = 0
269-
continue
270-
271-
dt64_to_dtstruct(dtindex[i], &dts)
272-
273-
if ((dts.month - start_month) % 3 == 0) and (
274-
dts.day == get_firstbday(dts.year, dts.month)):
275-
out[i] = 1
276-
277-
else:
278-
for i in range(count):
279-
if dtindex[i] == NPY_NAT:
280-
out[i] = 0
281-
continue
282-
283-
dt64_to_dtstruct(dtindex[i], &dts)
284-
285-
if ((dts.month - start_month) % 3 == 0) and dts.day == 1:
286-
out[i] = 1
287-
288-
elif field == 'is_quarter_end':
289-
if is_business:
290-
for i in range(count):
291-
if dtindex[i] == NPY_NAT:
292-
out[i] = 0
293-
continue
294-
295-
dt64_to_dtstruct(dtindex[i], &dts)
296-
297-
if ((dts.month - end_month) % 3 == 0) and (
298-
dts.day == get_lastbday(dts.year, dts.month)):
299-
out[i] = 1
300-
301-
else:
302-
for i in range(count):
303-
if dtindex[i] == NPY_NAT:
304-
out[i] = 0
305-
continue
306-
307-
dt64_to_dtstruct(dtindex[i], &dts)
308-
309-
if ((dts.month - end_month) % 3 == 0) and (
310-
dts.day == get_days_in_month(dts.year, dts.month)):
311-
out[i] = 1
231+
compare_month = start_month if "start" in field else end_month
232+
if "month" in field:
233+
modby = 1
234+
elif "quarter" in field:
235+
modby = 3
236+
else:
237+
modby = 12
312238

313-
elif field == 'is_year_start':
239+
if field in ["is_month_start", "is_quarter_start", "is_year_start"]:
314240
if is_business:
315241
for i in range(count):
316242
if dtindex[i] == NPY_NAT:
@@ -319,7 +245,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
319245

320246
dt64_to_dtstruct(dtindex[i], &dts)
321247

322-
if (dts.month == start_month) and (
248+
if _is_on_month(dts.month, compare_month, modby) and (
323249
dts.day == get_firstbday(dts.year, dts.month)):
324250
out[i] = 1
325251

@@ -331,10 +257,10 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
331257

332258
dt64_to_dtstruct(dtindex[i], &dts)
333259

334-
if (dts.month == start_month) and dts.day == 1:
260+
if _is_on_month(dts.month, compare_month, modby) and dts.day == 1:
335261
out[i] = 1
336262

337-
elif field == 'is_year_end':
263+
elif field in ["is_month_end", "is_quarter_end", "is_year_end"]:
338264
if is_business:
339265
for i in range(count):
340266
if dtindex[i] == NPY_NAT:
@@ -343,7 +269,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
343269

344270
dt64_to_dtstruct(dtindex[i], &dts)
345271

346-
if (dts.month == end_month) and (
272+
if _is_on_month(dts.month, compare_month, modby) and (
347273
dts.day == get_lastbday(dts.year, dts.month)):
348274
out[i] = 1
349275

@@ -355,7 +281,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
355281

356282
dt64_to_dtstruct(dtindex[i], &dts)
357283

358-
if (dts.month == end_month) and (
284+
if _is_on_month(dts.month, compare_month, modby) and (
359285
dts.day == get_days_in_month(dts.year, dts.month)):
360286
out[i] = 1
361287

0 commit comments

Comments
 (0)