@@ -174,6 +174,18 @@ def get_date_name_field(const int64_t[:] dtindex, str field, object locale=None)
174
174
return out
175
175
176
176
177
+ cdef inline bint _is_on_month(int month, int compare_month, int modby) nogil:
178
+ """
179
+ Analogous to DateOffset.is_on_offset checking for the month part of a date.
180
+ """
181
+ if modby == 1 :
182
+ return True
183
+ elif modby == 3 :
184
+ return (month - compare_month) % 3 == 0
185
+ else :
186
+ return month == compare_month
187
+
188
+
177
189
@ cython.wraparound (False )
178
190
@ cython.boundscheck (False )
179
191
def get_start_end_field (const int64_t[:] dtindex , str field ,
@@ -191,6 +203,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
191
203
int start_month = 1
192
204
ndarray[int8_t] out
193
205
npy_datetimestruct dts
206
+ int compare_month, modby
194
207
195
208
out = np.zeros(count, dtype = ' int8' )
196
209
@@ -215,102 +228,15 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
215
228
end_month = 12
216
229
start_month = 1
217
230
218
- if field == ' is_month_start' :
219
- if is_business:
220
- for i in range (count):
221
- if dtindex[i] == NPY_NAT:
222
- out[i] = 0
223
- continue
224
-
225
- dt64_to_dtstruct(dtindex[i], & dts)
226
-
227
- if dts.day == get_firstbday(dts.year, dts.month):
228
- out[i] = 1
229
-
230
- else :
231
- for i in range (count):
232
- if dtindex[i] == NPY_NAT:
233
- out[i] = 0
234
- continue
235
-
236
- dt64_to_dtstruct(dtindex[i], & dts)
237
-
238
- if dts.day == 1 :
239
- out[i] = 1
240
-
241
- elif field == ' is_month_end' :
242
- if is_business:
243
- for i in range (count):
244
- if dtindex[i] == NPY_NAT:
245
- out[i] = 0
246
- continue
247
-
248
- dt64_to_dtstruct(dtindex[i], & dts)
249
-
250
- if dts.day == get_lastbday(dts.year, dts.month):
251
- out[i] = 1
252
-
253
- else :
254
- for i in range (count):
255
- if dtindex[i] == NPY_NAT:
256
- out[i] = 0
257
- continue
258
-
259
- dt64_to_dtstruct(dtindex[i], & dts)
260
-
261
- if dts.day == get_days_in_month(dts.year, dts.month):
262
- out[i] = 1
263
-
264
- elif field == ' is_quarter_start' :
265
- if is_business:
266
- for i in range (count):
267
- if dtindex[i] == NPY_NAT:
268
- out[i] = 0
269
- continue
270
-
271
- dt64_to_dtstruct(dtindex[i], & dts)
272
-
273
- if ((dts.month - start_month) % 3 == 0 ) and (
274
- dts.day == get_firstbday(dts.year, dts.month)):
275
- out[i] = 1
276
-
277
- else :
278
- for i in range (count):
279
- if dtindex[i] == NPY_NAT:
280
- out[i] = 0
281
- continue
282
-
283
- dt64_to_dtstruct(dtindex[i], & dts)
284
-
285
- if ((dts.month - start_month) % 3 == 0 ) and dts.day == 1 :
286
- out[i] = 1
287
-
288
- elif field == ' is_quarter_end' :
289
- if is_business:
290
- for i in range (count):
291
- if dtindex[i] == NPY_NAT:
292
- out[i] = 0
293
- continue
294
-
295
- dt64_to_dtstruct(dtindex[i], & dts)
296
-
297
- if ((dts.month - end_month) % 3 == 0 ) and (
298
- dts.day == get_lastbday(dts.year, dts.month)):
299
- out[i] = 1
300
-
301
- else :
302
- for i in range (count):
303
- if dtindex[i] == NPY_NAT:
304
- out[i] = 0
305
- continue
306
-
307
- dt64_to_dtstruct(dtindex[i], & dts)
308
-
309
- if ((dts.month - end_month) % 3 == 0 ) and (
310
- dts.day == get_days_in_month(dts.year, dts.month)):
311
- out[i] = 1
231
+ compare_month = start_month if " start" in field else end_month
232
+ if " month" in field:
233
+ modby = 1
234
+ elif " quarter" in field:
235
+ modby = 3
236
+ else :
237
+ modby = 12
312
238
313
- elif field == ' is_year_start' :
239
+ if field in [ " is_month_start " , " is_quarter_start " , " is_year_start" ] :
314
240
if is_business:
315
241
for i in range (count):
316
242
if dtindex[i] == NPY_NAT:
@@ -319,7 +245,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
319
245
320
246
dt64_to_dtstruct(dtindex[i], & dts)
321
247
322
- if (dts.month == start_month ) and (
248
+ if _is_on_month (dts.month, compare_month, modby ) and (
323
249
dts.day == get_firstbday(dts.year, dts.month)):
324
250
out[i] = 1
325
251
@@ -331,10 +257,10 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
331
257
332
258
dt64_to_dtstruct(dtindex[i], & dts)
333
259
334
- if (dts.month == start_month ) and dts.day == 1 :
260
+ if _is_on_month (dts.month, compare_month, modby ) and dts.day == 1 :
335
261
out[i] = 1
336
262
337
- elif field == ' is_year_end' :
263
+ elif field in [ " is_month_end " , " is_quarter_end " , " is_year_end" ] :
338
264
if is_business:
339
265
for i in range (count):
340
266
if dtindex[i] == NPY_NAT:
@@ -343,7 +269,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
343
269
344
270
dt64_to_dtstruct(dtindex[i], & dts)
345
271
346
- if (dts.month == end_month ) and (
272
+ if _is_on_month (dts.month, compare_month, modby ) and (
347
273
dts.day == get_lastbday(dts.year, dts.month)):
348
274
out[i] = 1
349
275
@@ -355,7 +281,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
355
281
356
282
dt64_to_dtstruct(dtindex[i], & dts)
357
283
358
- if (dts.month == end_month ) and (
284
+ if _is_on_month (dts.month, compare_month, modby ) and (
359
285
dts.day == get_days_in_month(dts.year, dts.month)):
360
286
out[i] = 1
361
287
0 commit comments