@@ -355,36 +355,44 @@ def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
355
355
356
356
@ cython.wraparound (False )
357
357
@ cython.boundscheck (False )
358
- def dt64arr_to_periodarr (const int64_t[:] stamps , int freq , tzinfo tz ):
358
+ def dt64arr_to_periodarr (ndarray stamps , int freq , tzinfo tz ):
359
+ # stamps is int64_t, arbitrary ndim
359
360
cdef:
360
361
Localizer info = Localizer(tz)
361
- int64_t utc_val, local_val
362
- Py_ssize_t pos, i, n = stamps.shape[ 0 ]
362
+ Py_ssize_t pos, i, n = stamps.size
363
+ int64_t utc_val, local_val, res_val
363
364
int64_t* tdata = NULL
364
365
365
366
npy_datetimestruct dts
366
- int64_t[::1 ] result = np.empty(n, dtype = np.int64)
367
+ ndarray result = cnp.PyArray_EMPTY(stamps.ndim, stamps.shape, cnp.NPY_INT64, 0 )
368
+ cnp.broadcast mi = cnp.PyArray_MultiIterNew2(result, stamps)
367
369
368
370
if info.use_dst:
369
371
tdata = < int64_t* > cnp.PyArray_DATA(info.trans)
370
372
371
373
for i in range (n):
372
- utc_val = stamps[i]
373
- if utc_val == NPY_NAT:
374
- result[i] = NPY_NAT
375
- continue
374
+ # Analogous to: utc_val = stamps[i]
375
+ utc_val = (< int64_t* > cnp.PyArray_MultiIter_DATA(mi, 1 ))[0 ]
376
376
377
- if info.use_utc:
378
- local_val = utc_val
379
- elif info.use_tzlocal:
380
- local_val = utc_val + localize_tzinfo_api(utc_val, tz)
381
- elif info.use_fixed:
382
- local_val = utc_val + info.delta
377
+ if utc_val == NPY_NAT:
378
+ res_val = NPY_NAT
383
379
else :
384
- pos = bisect_right_i8(tdata, utc_val, info.ntrans) - 1
385
- local_val = utc_val + info.deltas[pos]
380
+ if info.use_utc:
381
+ local_val = utc_val
382
+ elif info.use_tzlocal:
383
+ local_val = utc_val + localize_tzinfo_api(utc_val, tz)
384
+ elif info.use_fixed:
385
+ local_val = utc_val + info.delta
386
+ else :
387
+ pos = bisect_right_i8(tdata, utc_val, info.ntrans) - 1
388
+ local_val = utc_val + info.deltas[pos]
386
389
387
- dt64_to_dtstruct(local_val, & dts)
388
- result[i] = get_period_ordinal(& dts, freq)
390
+ dt64_to_dtstruct(local_val, & dts)
391
+ res_val = get_period_ordinal(& dts, freq)
392
+
393
+ # Analogous to: result[i] = res_val
394
+ (< int64_t* > cnp.PyArray_MultiIter_DATA(mi, 0 ))[0 ] = res_val
389
395
390
- return result.base # .base to get underlying ndarray
396
+ cnp.PyArray_MultiIter_NEXT(mi)
397
+
398
+ return result
0 commit comments