@@ -1195,6 +1195,7 @@ ctypedef fused diff_t:
1195
1195
ctypedef fused out_t:
1196
1196
float32_t
1197
1197
float64_t
1198
+ int64_t
1198
1199
1199
1200
1200
1201
@ cython.boundscheck (False )
@@ -1204,11 +1205,13 @@ def diff_2d(
1204
1205
ndarray[out_t , ndim = 2 ] out,
1205
1206
Py_ssize_t periods ,
1206
1207
int axis ,
1208
+ bint datetimelike = False ,
1207
1209
):
1208
1210
cdef:
1209
1211
Py_ssize_t i, j, sx, sy, start, stop
1210
1212
bint f_contig = arr.flags.f_contiguous
1211
1213
# bint f_contig = arr.is_f_contig() # TODO(cython 3)
1214
+ diff_t left, right
1212
1215
1213
1216
# Disable for unsupported dtype combinations,
1214
1217
# see https://github.com/cython/cython/issues/2646
@@ -1218,6 +1221,9 @@ def diff_2d(
1218
1221
elif (out_t is float64_t
1219
1222
and (diff_t is float32_t or diff_t is int8_t or diff_t is int16_t)):
1220
1223
raise NotImplementedError
1224
+ elif out_t is int64_t and diff_t is not int64_t:
1225
+ # We only have out_t of int64_t if we have datetimelike
1226
+ raise NotImplementedError
1221
1227
else :
1222
1228
# We put this inside an indented else block to avoid cython build
1223
1229
# warnings about unreachable code
@@ -1231,15 +1237,31 @@ def diff_2d(
1231
1237
start, stop = 0 , sx + periods
1232
1238
for j in range (sy):
1233
1239
for i in range (start, stop):
1234
- out[i, j] = arr[i, j] - arr[i - periods, j]
1240
+ left = arr[i, j]
1241
+ right = arr[i - periods, j]
1242
+ if out_t is int64_t and datetimelike:
1243
+ if left == NPY_NAT or right == NPY_NAT:
1244
+ out[i, j] = NPY_NAT
1245
+ else :
1246
+ out[i, j] = left - right
1247
+ else :
1248
+ out[i, j] = left - right
1235
1249
else :
1236
1250
if periods >= 0 :
1237
1251
start, stop = periods, sy
1238
1252
else :
1239
1253
start, stop = 0 , sy + periods
1240
1254
for j in range (start, stop):
1241
1255
for i in range (sx):
1242
- out[i, j] = arr[i, j] - arr[i, j - periods]
1256
+ left = arr[i, j]
1257
+ right = arr[i, j - periods]
1258
+ if out_t is int64_t and datetimelike:
1259
+ if left == NPY_NAT or right == NPY_NAT:
1260
+ out[i, j] = NPY_NAT
1261
+ else :
1262
+ out[i, j] = left - right
1263
+ else :
1264
+ out[i, j] = left - right
1243
1265
else :
1244
1266
if axis == 0 :
1245
1267
if periods >= 0 :
@@ -1248,15 +1270,31 @@ def diff_2d(
1248
1270
start, stop = 0 , sx + periods
1249
1271
for i in range (start, stop):
1250
1272
for j in range (sy):
1251
- out[i, j] = arr[i, j] - arr[i - periods, j]
1273
+ left = arr[i, j]
1274
+ right = arr[i - periods, j]
1275
+ if out_t is int64_t and datetimelike:
1276
+ if left == NPY_NAT or right == NPY_NAT:
1277
+ out[i, j] = NPY_NAT
1278
+ else :
1279
+ out[i, j] = left - right
1280
+ else :
1281
+ out[i, j] = left - right
1252
1282
else :
1253
1283
if periods >= 0 :
1254
1284
start, stop = periods, sy
1255
1285
else :
1256
1286
start, stop = 0 , sy + periods
1257
1287
for i in range (sx):
1258
1288
for j in range (start, stop):
1259
- out[i, j] = arr[i, j] - arr[i, j - periods]
1289
+ left = arr[i, j]
1290
+ right = arr[i, j - periods]
1291
+ if out_t is int64_t and datetimelike:
1292
+ if left == NPY_NAT or right == NPY_NAT:
1293
+ out[i, j] = NPY_NAT
1294
+ else :
1295
+ out[i, j] = left - right
1296
+ else :
1297
+ out[i, j] = left - right
1260
1298
1261
1299
1262
1300
# generated from template
0 commit comments