@@ -38,7 +38,8 @@ from .tzconversion cimport Localizer
38
38
39
39
@ cython.boundscheck (False )
40
40
@ cython.wraparound (False )
41
- def tz_convert_from_utc (const int64_t[:] stamps , tzinfo tz ):
41
+ def tz_convert_from_utc (ndarray stamps , tzinfo tz ):
42
+ # stamps is int64_t, arbitrary ndim
42
43
"""
43
44
Convert the values (in i8) from UTC to tz
44
45
@@ -54,27 +55,33 @@ def tz_convert_from_utc(const int64_t[:] stamps, tzinfo tz):
54
55
cdef:
55
56
Localizer info = Localizer(tz)
56
57
int64_t utc_val, local_val
57
- Py_ssize_t pos, i, n = stamps.shape[ 0 ]
58
+ Py_ssize_t pos, i, n = stamps.size
58
59
59
- int64_t[::1 ] result
60
+ ndarray result
61
+ cnp.broadcast mi
60
62
61
63
if tz is None or is_utc(tz) or stamps.size == 0 :
62
64
# Much faster than going through the "standard" pattern below
63
- return stamps.base. copy()
65
+ return stamps.copy()
64
66
65
- result = np.empty(n, dtype = np.int64)
67
+ result = cnp.PyArray_EMPTY(stamps.ndim, stamps.shape, cnp.NPY_INT64, 0 )
68
+ mi = cnp.PyArray_MultiIterNew2(result, stamps)
66
69
67
70
for i in range (n):
68
- utc_val = stamps[i]
71
+ # Analogous to: utc_val = stamps[i]
72
+ utc_val = (< int64_t* > cnp.PyArray_MultiIter_DATA(mi, 1 ))[0 ]
73
+
69
74
if utc_val == NPY_NAT:
70
- result[i] = NPY_NAT
71
- continue
75
+ local_val = NPY_NAT
76
+ else :
77
+ local_val = info.utc_val_to_local_val(utc_val, & pos)
72
78
73
- local_val = info.utc_val_to_local_val(utc_val, & pos)
79
+ # Analogous to: result[i] = local_val
80
+ (< int64_t* > cnp.PyArray_MultiIter_DATA(mi, 0 ))[0 ] = local_val
74
81
75
- result[i] = local_val
82
+ cnp.PyArray_MultiIter_NEXT(mi)
76
83
77
- return result.base
84
+ return result
78
85
79
86
80
87
# -------------------------------------------------------------------------
@@ -83,12 +90,13 @@ def tz_convert_from_utc(const int64_t[:] stamps, tzinfo tz):
83
90
@ cython.wraparound (False )
84
91
@ cython.boundscheck (False )
85
92
def ints_to_pydatetime (
86
- const int64_t[:] stamps ,
93
+ ndarray stamps ,
87
94
tzinfo tz = None ,
88
95
BaseOffset freq = None ,
89
96
bint fold = False ,
90
97
str box = " datetime"
91
98
) -> np.ndarray:
99
+ # stamps is int64 , arbitrary ndim
92
100
"""
93
101
Convert an i8 repr to an ndarray of datetimes , date , time or Timestamp.
94
102
@@ -119,13 +127,21 @@ def ints_to_pydatetime(
119
127
cdef:
120
128
Localizer info = Localizer(tz)
121
129
int64_t utc_val , local_val
122
- Py_ssize_t i , n = stamps.shape[ 0 ]
130
+ Py_ssize_t i , n = stamps.size
123
131
Py_ssize_t pos = - 1 # unused, avoid not-initialized warning
124
132
125
133
npy_datetimestruct dts
126
134
tzinfo new_tz
127
- ndarray[object] result = np.empty(n, dtype = object )
128
135
bint use_date = False , use_time = False , use_ts = False , use_pydt = False
136
+ object res_val
137
+
138
+ # Note that `result` (and thus `result_flat`) is C-order and
139
+ # `it` iterates C-order as well , so the iteration matches
140
+ # See discussion at
141
+ # github.com/pandas-dev/pandas/pull/46886#discussion_r860261305
142
+ ndarray result = cnp.PyArray_EMPTY(stamps.ndim, stamps.shape, cnp.NPY_OBJECT, 0 )
143
+ object[::1] res_flat = result.ravel() # should NOT be a copy
144
+ cnp.flatiter it = cnp.PyArray_IterNew(stamps)
129
145
130
146
if box == "date":
131
147
assert (tz is None ), "tz should be None when converting to date"
@@ -142,31 +158,44 @@ def ints_to_pydatetime(
142
158
)
143
159
144
160
for i in range(n ):
145
- utc_val = stamps[i]
161
+ # Analogous to: utc_val = stamps[i]
162
+ utc_val = (< int64_t* > cnp.PyArray_ITER_DATA(it))[0 ]
163
+
146
164
new_tz = tz
147
165
148
166
if utc_val == NPY_NAT:
149
- result[i] = < object > NaT
150
- continue
167
+ res_val = < object > NaT
151
168
152
- local_val = info.utc_val_to_local_val(utc_val, & pos)
153
- if info.use_pytz:
154
- # find right representation of dst etc in pytz timezone
155
- new_tz = tz._tzinfos[tz._transition_info[pos]]
156
-
157
- dt64_to_dtstruct(local_val, & dts)
158
-
159
- if use_ts:
160
- result[i] = create_timestamp_from_ts(utc_val, dts, new_tz, freq, fold)
161
- elif use_pydt:
162
- result[i] = datetime(
163
- dts.year, dts.month, dts.day, dts.hour, dts.min, dts.sec, dts.us,
164
- new_tz, fold = fold,
165
- )
166
- elif use_date:
167
- result[i] = date(dts.year, dts.month, dts.day)
168
169
else :
169
- result[i] = time(dts.hour, dts.min, dts.sec, dts.us, new_tz, fold = fold)
170
+
171
+ local_val = info.utc_val_to_local_val(utc_val, & pos)
172
+ if info.use_pytz:
173
+ # find right representation of dst etc in pytz timezone
174
+ new_tz = tz._tzinfos[tz._transition_info[pos]]
175
+
176
+ dt64_to_dtstruct(local_val, & dts)
177
+
178
+ if use_ts:
179
+ res_val = create_timestamp_from_ts(utc_val, dts, new_tz, freq, fold)
180
+ elif use_pydt:
181
+ res_val = datetime(
182
+ dts.year, dts.month, dts.day, dts.hour, dts.min, dts.sec, dts.us,
183
+ new_tz, fold = fold,
184
+ )
185
+ elif use_date:
186
+ res_val = date(dts.year, dts.month, dts.day)
187
+ else :
188
+ res_val = time(dts.hour, dts.min, dts.sec, dts.us, new_tz, fold = fold)
189
+
190
+ # Note: we can index result directly instead of using PyArray_MultiIter_DATA
191
+ # like we do for the other functions because result is known C-contiguous
192
+ # and `it` (from PyArray_IterNew) iterates in C-order as well. The usual pattern
193
+ # does not seem to work with object dtype.
194
+ # See discussion at
195
+ # github.com/pandas-dev/pandas/pull/46886#discussion_r860261305
196
+ res_flat[i] = res_val
197
+
198
+ cnp.PyArray_ITER_NEXT(it)
170
199
171
200
return result
172
201
@@ -190,27 +219,33 @@ cdef inline c_Resolution _reso_stamp(npy_datetimestruct *dts):
190
219
191
220
@cython.wraparound(False)
@cython.boundscheck(False)
def get_resolution(ndarray stamps, tzinfo tz=None) -> Resolution:
    """
    Reduce an array of i8 timestamps to a single Resolution.

    Each non-NaT value is converted from UTC to local time for ``tz``,
    broken out into a datetimestruct and classified by ``_reso_stamp``.
    The smallest ``c_Resolution`` value seen is kept, starting from
    ``RESO_DAY``.

    Parameters
    ----------
    stamps : ndarray
        int64 values, any ndim.  The data pointer of each element is read
        directly as ``int64_t`` (assumes a native, aligned int64 dtype --
        TODO confirm against callers).
    tz : tzinfo or None, default None
        Passed to ``Localizer`` for the UTC -> local conversion.

    Returns
    -------
    Resolution
        ``Resolution(RESO_DAY)`` when ``stamps`` is empty or all-NaT.
    """
    cdef:
        Localizer info = Localizer(tz)
        int64_t utc_val, local_val
        Py_ssize_t i, n = stamps.size
        Py_ssize_t pos = -1  # unused, avoid not-initialized warning
        cnp.flatiter it = cnp.PyArray_IterNew(stamps)

        npy_datetimestruct dts
        c_Resolution reso = c_Resolution.RESO_DAY, curr_reso

    for i in range(n):
        # Analogous to: utc_val = stamps[i]
        # Read the int64 in place rather than boxing it through
        # PyArray_GETITEM; this matches the access used in ints_to_pydatetime.
        utc_val = (<int64_t*>cnp.PyArray_ITER_DATA(it))[0]

        # NaT entries carry no resolution information; skip them but still
        # advance the iterator below.
        if utc_val != NPY_NAT:
            local_val = info.utc_val_to_local_val(utc_val, &pos)

            dt64_to_dtstruct(local_val, &dts)
            curr_reso = _reso_stamp(&dts)
            if curr_reso < reso:
                reso = curr_reso

        cnp.PyArray_ITER_NEXT(it)

    return Resolution(reso)
216
251
@@ -221,7 +256,8 @@ def get_resolution(const int64_t[:] stamps, tzinfo tz=None) -> Resolution:
221
256
@ cython.cdivision (False )
222
257
@ cython.wraparound (False )
223
258
@ cython.boundscheck (False )
224
- cpdef ndarray[int64_t] normalize_i8_timestamps(const int64_t[:] stamps, tzinfo tz):
259
+ cpdef ndarray normalize_i8_timestamps(ndarray stamps, tzinfo tz):
260
+ # stamps is int64_t, arbitrary ndim
225
261
"""
226
262
Normalize each of the (nanosecond) timezone aware timestamps in the given
227
263
array by rounding down to the beginning of the day (i.e. midnight).
@@ -238,28 +274,35 @@ cpdef ndarray[int64_t] normalize_i8_timestamps(const int64_t[:] stamps, tzinfo t
238
274
"""
239
275
cdef:
240
276
Localizer info = Localizer(tz)
241
- int64_t utc_val, local_val
242
- Py_ssize_t i, n = stamps.shape[ 0 ]
277
+ int64_t utc_val, local_val, res_val
278
+ Py_ssize_t i, n = stamps.size
243
279
Py_ssize_t pos = - 1 # unused, avoid not-initialized warning
244
280
245
- int64_t[::1 ] result = np.empty(n, dtype = np.int64)
281
+ ndarray result = cnp.PyArray_EMPTY(stamps.ndim, stamps.shape, cnp.NPY_INT64, 0 )
282
+ cnp.broadcast mi = cnp.PyArray_MultiIterNew2(result, stamps)
246
283
247
284
for i in range (n):
248
- utc_val = stamps[i]
285
+ # Analogous to: utc_val = stamps[i]
286
+ utc_val = (< int64_t* > cnp.PyArray_MultiIter_DATA(mi, 1 ))[0 ]
287
+
249
288
if utc_val == NPY_NAT:
250
- result[i] = NPY_NAT
251
- continue
289
+ res_val = NPY_NAT
290
+ else :
291
+ local_val = info.utc_val_to_local_val(utc_val, & pos)
292
+ res_val = local_val - (local_val % DAY_NANOS)
252
293
253
- local_val = info.utc_val_to_local_val(utc_val, & pos)
294
+ # Analogous to: result[i] = res_val
295
+ (< int64_t* > cnp.PyArray_MultiIter_DATA(mi, 0 ))[0 ] = res_val
254
296
255
- result[i] = local_val - (local_val % DAY_NANOS )
297
+ cnp.PyArray_MultiIter_NEXT(mi )
256
298
257
- return result.base # `.base` to access underlying ndarray
299
+ return result
258
300
259
301
260
302
@ cython.wraparound (False )
261
303
@ cython.boundscheck (False )
262
- def is_date_array_normalized (const int64_t[:] stamps , tzinfo tz = None ) -> bool:
304
+ def is_date_array_normalized (ndarray stamps , tzinfo tz = None ) -> bool:
305
+ # stamps is int64_t , arbitrary ndim
263
306
"""
264
307
Check if all of the given (nanosecond ) timestamps are normalized to
265
308
midnight , i.e. hour == minute == second == 0. If the optional timezone
@@ -277,16 +320,21 @@ def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
277
320
cdef:
278
321
Localizer info = Localizer(tz)
279
322
int64_t utc_val , local_val
280
- Py_ssize_t i , n = stamps.shape[ 0 ]
323
+ Py_ssize_t i , n = stamps.size
281
324
Py_ssize_t pos = - 1 # unused, avoid not-initialized warning
325
+ cnp.flatiter it = cnp.PyArray_IterNew(stamps)
282
326
283
327
for i in range(n ):
284
- utc_val = stamps[i]
328
+ # Analogous to: utc_val = stamps[i]
329
+ utc_val = cnp.PyArray_GETITEM(stamps, cnp.PyArray_ITER_DATA(it))
330
+
285
331
local_val = info.utc_val_to_local_val(utc_val, & pos)
286
332
287
333
if local_val % DAY_NANOS != 0 :
288
334
return False
289
335
336
+ cnp.PyArray_ITER_NEXT(it)
337
+
290
338
return True
291
339
292
340
0 commit comments