Skip to content

Commit d86e200

Browse files
authored
REF: handle 2D in tslibs.vectorized (#46886)
1 parent b19c203 commit d86e200

File tree

3 files changed

+111
-62
lines changed

3 files changed

+111
-62
lines changed

pandas/_libs/tslibs/timedeltas.pyx

+1
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ def array_to_timedelta64(
354354
raise ValueError(
355355
"unit must not be specified if the input contains a str"
356356
)
357+
cnp.PyArray_ITER_NEXT(it)
357358

358359
# Usually, we have all strings. If so, we hit the fast path.
359360
# If this path fails, we try conversion a different way, and

pandas/_libs/tslibs/vectorized.pyi

+6-6
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,24 @@ from pandas._libs.tslibs.offsets import BaseOffset
1111
from pandas._typing import npt
1212

1313
def dt64arr_to_periodarr(
14-
stamps: npt.NDArray[np.int64], # const int64_t[:]
14+
stamps: npt.NDArray[np.int64],
1515
freq: int,
1616
tz: tzinfo | None,
17-
) -> npt.NDArray[np.int64]: ... # np.ndarray[np.int64, ndim=1]
17+
) -> npt.NDArray[np.int64]: ...
1818
def is_date_array_normalized(
19-
stamps: npt.NDArray[np.int64], # const int64_t[:]
19+
stamps: npt.NDArray[np.int64],
2020
tz: tzinfo | None = ...,
2121
) -> bool: ...
2222
def normalize_i8_timestamps(
23-
stamps: npt.NDArray[np.int64], # const int64_t[:]
23+
stamps: npt.NDArray[np.int64],
2424
tz: tzinfo | None,
2525
) -> npt.NDArray[np.int64]: ...
2626
def get_resolution(
27-
stamps: npt.NDArray[np.int64], # const int64_t[:]
27+
stamps: npt.NDArray[np.int64],
2828
tz: tzinfo | None = ...,
2929
) -> Resolution: ...
3030
def ints_to_pydatetime(
31-
arr: npt.NDArray[np.int64], # const int64_t[:}]
31+
arr: npt.NDArray[np.int64],
3232
tz: tzinfo | None = ...,
3333
freq: BaseOffset | None = ...,
3434
fold: bool = ...,

pandas/_libs/tslibs/vectorized.pyx

+104-56
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ from .tzconversion cimport Localizer
3838

3939
@cython.boundscheck(False)
4040
@cython.wraparound(False)
41-
def tz_convert_from_utc(const int64_t[:] stamps, tzinfo tz):
41+
def tz_convert_from_utc(ndarray stamps, tzinfo tz):
42+
# stamps is int64_t, arbitrary ndim
4243
"""
4344
Convert the values (in i8) from UTC to tz
4445
@@ -54,27 +55,33 @@ def tz_convert_from_utc(const int64_t[:] stamps, tzinfo tz):
5455
cdef:
5556
Localizer info = Localizer(tz)
5657
int64_t utc_val, local_val
57-
Py_ssize_t pos, i, n = stamps.shape[0]
58+
Py_ssize_t pos, i, n = stamps.size
5859

59-
int64_t[::1] result
60+
ndarray result
61+
cnp.broadcast mi
6062

6163
if tz is None or is_utc(tz) or stamps.size == 0:
6264
# Much faster than going through the "standard" pattern below
63-
return stamps.base.copy()
65+
return stamps.copy()
6466

65-
result = np.empty(n, dtype=np.int64)
67+
result = cnp.PyArray_EMPTY(stamps.ndim, stamps.shape, cnp.NPY_INT64, 0)
68+
mi = cnp.PyArray_MultiIterNew2(result, stamps)
6669

6770
for i in range(n):
68-
utc_val = stamps[i]
71+
# Analogous to: utc_val = stamps[i]
72+
utc_val = (<int64_t*>cnp.PyArray_MultiIter_DATA(mi, 1))[0]
73+
6974
if utc_val == NPY_NAT:
70-
result[i] = NPY_NAT
71-
continue
75+
local_val = NPY_NAT
76+
else:
77+
local_val = info.utc_val_to_local_val(utc_val, &pos)
7278

73-
local_val = info.utc_val_to_local_val(utc_val, &pos)
79+
# Analogous to: result[i] = local_val
80+
(<int64_t*>cnp.PyArray_MultiIter_DATA(mi, 0))[0] = local_val
7481

75-
result[i] = local_val
82+
cnp.PyArray_MultiIter_NEXT(mi)
7683

77-
return result.base
84+
return result
7885

7986

8087
# -------------------------------------------------------------------------
@@ -83,12 +90,13 @@ def tz_convert_from_utc(const int64_t[:] stamps, tzinfo tz):
8390
@cython.wraparound(False)
8491
@cython.boundscheck(False)
8592
def ints_to_pydatetime(
86-
const int64_t[:] stamps,
93+
ndarray stamps,
8794
tzinfo tz=None,
8895
BaseOffset freq=None,
8996
bint fold=False,
9097
str box="datetime"
9198
) -> np.ndarray:
99+
# stamps is int64, arbitrary ndim
92100
"""
93101
Convert an i8 repr to an ndarray of datetimes, date, time or Timestamp.
94102

@@ -119,13 +127,21 @@ def ints_to_pydatetime(
119127
cdef:
120128
Localizer info = Localizer(tz)
121129
int64_t utc_val, local_val
122-
Py_ssize_t i, n = stamps.shape[0]
130+
Py_ssize_t i, n = stamps.size
123131
Py_ssize_t pos = -1 # unused, avoid not-initialized warning
124132

125133
npy_datetimestruct dts
126134
tzinfo new_tz
127-
ndarray[object] result = np.empty(n, dtype=object)
128135
bint use_date = False, use_time = False, use_ts = False, use_pydt = False
136+
object res_val
137+
138+
# Note that `result` (and thus `res_flat`) is C-order and
139+
# `it` iterates C-order as well, so the iteration matches
140+
# See discussion at
141+
# github.com/pandas-dev/pandas/pull/46886#discussion_r860261305
142+
ndarray result = cnp.PyArray_EMPTY(stamps.ndim, stamps.shape, cnp.NPY_OBJECT, 0)
143+
object[::1] res_flat = result.ravel() # should NOT be a copy
144+
cnp.flatiter it = cnp.PyArray_IterNew(stamps)
129145

130146
if box == "date":
131147
assert (tz is None), "tz should be None when converting to date"
@@ -142,31 +158,44 @@ def ints_to_pydatetime(
142158
)
143159

144160
for i in range(n):
145-
utc_val = stamps[i]
161+
# Analogous to: utc_val = stamps[i]
162+
utc_val = (<int64_t*>cnp.PyArray_ITER_DATA(it))[0]
163+
146164
new_tz = tz
147165

148166
if utc_val == NPY_NAT:
149-
result[i] = <object>NaT
150-
continue
167+
res_val = <object>NaT
151168

152-
local_val = info.utc_val_to_local_val(utc_val, &pos)
153-
if info.use_pytz:
154-
# find right representation of dst etc in pytz timezone
155-
new_tz = tz._tzinfos[tz._transition_info[pos]]
156-
157-
dt64_to_dtstruct(local_val, &dts)
158-
159-
if use_ts:
160-
result[i] = create_timestamp_from_ts(utc_val, dts, new_tz, freq, fold)
161-
elif use_pydt:
162-
result[i] = datetime(
163-
dts.year, dts.month, dts.day, dts.hour, dts.min, dts.sec, dts.us,
164-
new_tz, fold=fold,
165-
)
166-
elif use_date:
167-
result[i] = date(dts.year, dts.month, dts.day)
168169
else:
169-
result[i] = time(dts.hour, dts.min, dts.sec, dts.us, new_tz, fold=fold)
170+
171+
local_val = info.utc_val_to_local_val(utc_val, &pos)
172+
if info.use_pytz:
173+
# find right representation of dst etc in pytz timezone
174+
new_tz = tz._tzinfos[tz._transition_info[pos]]
175+
176+
dt64_to_dtstruct(local_val, &dts)
177+
178+
if use_ts:
179+
res_val = create_timestamp_from_ts(utc_val, dts, new_tz, freq, fold)
180+
elif use_pydt:
181+
res_val = datetime(
182+
dts.year, dts.month, dts.day, dts.hour, dts.min, dts.sec, dts.us,
183+
new_tz, fold=fold,
184+
)
185+
elif use_date:
186+
res_val = date(dts.year, dts.month, dts.day)
187+
else:
188+
res_val = time(dts.hour, dts.min, dts.sec, dts.us, new_tz, fold=fold)
189+
190+
# Note: we can index result directly instead of using PyArray_MultiIter_DATA
191+
# like we do for the other functions because result is known C-contiguous
192+
# and `it` (from PyArray_IterNew) visits `stamps` in C order. The usual pattern
193+
# does not seem to work with object dtype.
194+
# See discussion at
195+
# github.com/pandas-dev/pandas/pull/46886#discussion_r860261305
196+
res_flat[i] = res_val
197+
198+
cnp.PyArray_ITER_NEXT(it)
170199

171200
return result
172201

@@ -190,27 +219,33 @@ cdef inline c_Resolution _reso_stamp(npy_datetimestruct *dts):
190219

191220
@cython.wraparound(False)
@cython.boundscheck(False)
def get_resolution(ndarray stamps, tzinfo tz=None) -> Resolution:
    """
    Compute the overall resolution of an array of i8 timestamps.

    Every non-NaT value is converted from UTC to local time for ``tz``,
    decomposed into a datetime struct and classified with ``_reso_stamp``.
    The smallest ``c_Resolution`` value encountered is kept, starting from
    ``RESO_DAY``.

    Parameters
    ----------
    stamps : ndarray
        int64 values, any number of dimensions.
    tz : tzinfo or None, default None
        Timezone handed to ``Localizer`` for the UTC -> local conversion.

    Returns
    -------
    Resolution
        ``Resolution(RESO_DAY)`` when ``stamps`` is empty or contains
        only NaT.
    """
    # stamps is int64_t, any ndim
    cdef:
        Localizer info = Localizer(tz)
        int64_t utc_val, local_val
        Py_ssize_t i, n = stamps.size
        Py_ssize_t pos = -1  # unused, avoid not-initialized warning
        cnp.flatiter it = cnp.PyArray_IterNew(stamps)

        npy_datetimestruct dts
        c_Resolution reso = c_Resolution.RESO_DAY, curr_reso

    for i in range(n):
        # Analogous to: utc_val = stamps[i]
        # Read the int64 straight from the iterator's data pointer, the same
        # way the other functions in this module do, instead of boxing it
        # into a Python object with PyArray_GETITEM on every iteration.
        # This relies on `stamps` holding native int64 values.
        utc_val = (<int64_t*>cnp.PyArray_ITER_DATA(it))[0]

        if utc_val != NPY_NAT:
            # NaT entries carry no resolution information and are skipped.
            local_val = info.utc_val_to_local_val(utc_val, &pos)

            dt64_to_dtstruct(local_val, &dts)
            curr_reso = _reso_stamp(&dts)
            if curr_reso < reso:
                reso = curr_reso

        # Always advance, including for NaT, so `it` stays in step with `i`.
        cnp.PyArray_ITER_NEXT(it)

    return Resolution(reso)
216251

@@ -221,7 +256,8 @@ def get_resolution(const int64_t[:] stamps, tzinfo tz=None) -> Resolution:
221256
@cython.cdivision(False)
222257
@cython.wraparound(False)
223258
@cython.boundscheck(False)
224-
cpdef ndarray[int64_t] normalize_i8_timestamps(const int64_t[:] stamps, tzinfo tz):
259+
cpdef ndarray normalize_i8_timestamps(ndarray stamps, tzinfo tz):
260+
# stamps is int64_t, arbitrary ndim
225261
"""
226262
Normalize each of the (nanosecond) timezone aware timestamps in the given
227263
array by rounding down to the beginning of the day (i.e. midnight).
@@ -238,28 +274,35 @@ cpdef ndarray[int64_t] normalize_i8_timestamps(const int64_t[:] stamps, tzinfo t
238274
"""
239275
cdef:
240276
Localizer info = Localizer(tz)
241-
int64_t utc_val, local_val
242-
Py_ssize_t i, n = stamps.shape[0]
277+
int64_t utc_val, local_val, res_val
278+
Py_ssize_t i, n = stamps.size
243279
Py_ssize_t pos = -1 # unused, avoid not-initialized warning
244280

245-
int64_t[::1] result = np.empty(n, dtype=np.int64)
281+
ndarray result = cnp.PyArray_EMPTY(stamps.ndim, stamps.shape, cnp.NPY_INT64, 0)
282+
cnp.broadcast mi = cnp.PyArray_MultiIterNew2(result, stamps)
246283

247284
for i in range(n):
248-
utc_val = stamps[i]
285+
# Analogous to: utc_val = stamps[i]
286+
utc_val = (<int64_t*>cnp.PyArray_MultiIter_DATA(mi, 1))[0]
287+
249288
if utc_val == NPY_NAT:
250-
result[i] = NPY_NAT
251-
continue
289+
res_val = NPY_NAT
290+
else:
291+
local_val = info.utc_val_to_local_val(utc_val, &pos)
292+
res_val = local_val - (local_val % DAY_NANOS)
252293

253-
local_val = info.utc_val_to_local_val(utc_val, &pos)
294+
# Analogous to: result[i] = res_val
295+
(<int64_t*>cnp.PyArray_MultiIter_DATA(mi, 0))[0] = res_val
254296

255-
result[i] = local_val - (local_val % DAY_NANOS)
297+
cnp.PyArray_MultiIter_NEXT(mi)
256298

257-
return result.base # `.base` to access underlying ndarray
299+
return result
258300

259301

260302
@cython.wraparound(False)
261303
@cython.boundscheck(False)
262-
def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
304+
def is_date_array_normalized(ndarray stamps, tzinfo tz=None) -> bool:
305+
# stamps is int64_t, arbitrary ndim
263306
"""
264307
Check if all of the given (nanosecond) timestamps are normalized to
265308
midnight, i.e. hour == minute == second == 0. If the optional timezone
@@ -277,16 +320,21 @@ def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
277320
cdef:
278321
Localizer info = Localizer(tz)
279322
int64_t utc_val, local_val
280-
Py_ssize_t i, n = stamps.shape[0]
323+
Py_ssize_t i, n = stamps.size
281324
Py_ssize_t pos = -1 # unused, avoid not-initialized warning
325+
cnp.flatiter it = cnp.PyArray_IterNew(stamps)
282326

283327
for i in range(n):
284-
utc_val = stamps[i]
328+
# Analogous to: utc_val = stamps[i]
329+
utc_val = cnp.PyArray_GETITEM(stamps, cnp.PyArray_ITER_DATA(it))
330+
285331
local_val = info.utc_val_to_local_val(utc_val, &pos)
286332

287333
if local_val % DAY_NANOS != 0:
288334
return False
289335

336+
cnp.PyArray_ITER_NEXT(it)
337+
290338
return True
291339

292340

0 commit comments

Comments
 (0)