Skip to content

Commit 0013472

Browse files
authored
BUG: algos.diff with datetimelike and NaT (#37140)
1 parent 56f03c7 commit 0013472

File tree

3 files changed

+82
-7
lines changed

3 files changed

+82
-7
lines changed

pandas/_libs/algos.pyx

+42 -4
Original file line number | Diff line number | Diff line change
@@ -1195,6 +1195,7 @@ ctypedef fused diff_t:
11951195
ctypedef fused out_t:
11961196
float32_t
11971197
float64_t
1198+
int64_t
11981199

11991200

12001201
@cython.boundscheck(False)
@@ -1204,11 +1205,13 @@ def diff_2d(
12041205
ndarray[out_t, ndim=2] out,
12051206
Py_ssize_t periods,
12061207
int axis,
1208+
bint datetimelike=False,
12071209
):
12081210
cdef:
12091211
Py_ssize_t i, j, sx, sy, start, stop
12101212
bint f_contig = arr.flags.f_contiguous
12111213
# bint f_contig = arr.is_f_contig() # TODO(cython 3)
1214+
diff_t left, right
12121215

12131216
# Disable for unsupported dtype combinations,
12141217
# see https://github.com/cython/cython/issues/2646
@@ -1218,6 +1221,9 @@ def diff_2d(
12181221
elif (out_t is float64_t
12191222
and (diff_t is float32_t or diff_t is int8_t or diff_t is int16_t)):
12201223
raise NotImplementedError
1224+
elif out_t is int64_t and diff_t is not int64_t:
1225+
# We only have out_t of int64_t if we have datetimelike
1226+
raise NotImplementedError
12211227
else:
12221228
# We put this inside an indented else block to avoid cython build
12231229
# warnings about unreachable code
@@ -1231,15 +1237,31 @@ def diff_2d(
12311237
start, stop = 0, sx + periods
12321238
for j in range(sy):
12331239
for i in range(start, stop):
1234-
out[i, j] = arr[i, j] - arr[i - periods, j]
1240+
left = arr[i, j]
1241+
right = arr[i - periods, j]
1242+
if out_t is int64_t and datetimelike:
1243+
if left == NPY_NAT or right == NPY_NAT:
1244+
out[i, j] = NPY_NAT
1245+
else:
1246+
out[i, j] = left - right
1247+
else:
1248+
out[i, j] = left - right
12351249
else:
12361250
if periods >= 0:
12371251
start, stop = periods, sy
12381252
else:
12391253
start, stop = 0, sy + periods
12401254
for j in range(start, stop):
12411255
for i in range(sx):
1242-
out[i, j] = arr[i, j] - arr[i, j - periods]
1256+
left = arr[i, j]
1257+
right = arr[i, j - periods]
1258+
if out_t is int64_t and datetimelike:
1259+
if left == NPY_NAT or right == NPY_NAT:
1260+
out[i, j] = NPY_NAT
1261+
else:
1262+
out[i, j] = left - right
1263+
else:
1264+
out[i, j] = left - right
12431265
else:
12441266
if axis == 0:
12451267
if periods >= 0:
@@ -1248,15 +1270,31 @@ def diff_2d(
12481270
start, stop = 0, sx + periods
12491271
for i in range(start, stop):
12501272
for j in range(sy):
1251-
out[i, j] = arr[i, j] - arr[i - periods, j]
1273+
left = arr[i, j]
1274+
right = arr[i - periods, j]
1275+
if out_t is int64_t and datetimelike:
1276+
if left == NPY_NAT or right == NPY_NAT:
1277+
out[i, j] = NPY_NAT
1278+
else:
1279+
out[i, j] = left - right
1280+
else:
1281+
out[i, j] = left - right
12521282
else:
12531283
if periods >= 0:
12541284
start, stop = periods, sy
12551285
else:
12561286
start, stop = 0, sy + periods
12571287
for i in range(sx):
12581288
for j in range(start, stop):
1259-
out[i, j] = arr[i, j] - arr[i, j - periods]
1289+
left = arr[i, j]
1290+
right = arr[i, j - periods]
1291+
if out_t is int64_t and datetimelike:
1292+
if left == NPY_NAT or right == NPY_NAT:
1293+
out[i, j] = NPY_NAT
1294+
else:
1295+
out[i, j] = left - right
1296+
else:
1297+
out[i, j] = left - right
12601298

12611299

12621300
# generated from template

pandas/core/algorithms.py

+15 -3
Original file line number | Diff line number | Diff line change
@@ -1911,6 +1911,8 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
19111911

19121912
if is_extension_array_dtype(dtype):
19131913
if hasattr(arr, f"__{op.__name__}__"):
1914+
if axis != 0:
1915+
raise ValueError(f"cannot diff {type(arr).__name__} on axis={axis}")
19141916
return op(arr, arr.shift(n))
19151917
else:
19161918
warn(
@@ -1925,18 +1927,26 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
19251927
is_timedelta = False
19261928
is_bool = False
19271929
if needs_i8_conversion(arr.dtype):
1928-
dtype = np.float64
1930+
dtype = np.int64
19291931
arr = arr.view("i8")
19301932
na = iNaT
19311933
is_timedelta = True
19321934

19331935
elif is_bool_dtype(dtype):
1936+
# We have to cast in order to be able to hold np.nan
19341937
dtype = np.object_
19351938
is_bool = True
19361939

19371940
elif is_integer_dtype(dtype):
1941+
# We have to cast in order to be able to hold np.nan
19381942
dtype = np.float64
19391943

1944+
orig_ndim = arr.ndim
1945+
if orig_ndim == 1:
1946+
# reshape so we can always use algos.diff_2d
1947+
arr = arr.reshape(-1, 1)
1948+
# TODO: require axis == 0
1949+
19401950
dtype = np.dtype(dtype)
19411951
out_arr = np.empty(arr.shape, dtype=dtype)
19421952

@@ -1947,7 +1957,7 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
19471957
if arr.ndim == 2 and arr.dtype.name in _diff_special:
19481958
# TODO: can diff_2d dtype specialization troubles be fixed by defining
19491959
# out_arr inside diff_2d?
1950-
algos.diff_2d(arr, out_arr, n, axis)
1960+
algos.diff_2d(arr, out_arr, n, axis, datetimelike=is_timedelta)
19511961
else:
19521962
# To keep mypy happy, _res_indexer is a list while res_indexer is
19531963
# a tuple, ditto for lag_indexer.
@@ -1981,8 +1991,10 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
19811991
out_arr[res_indexer] = arr[res_indexer] - arr[lag_indexer]
19821992

19831993
if is_timedelta:
1984-
out_arr = out_arr.astype("int64").view("timedelta64[ns]")
1994+
out_arr = out_arr.view("timedelta64[ns]")
19851995

1996+
if orig_ndim == 1:
1997+
out_arr = out_arr[:, 0]
19861998
return out_arr
19871999

19882000

pandas/tests/test_algos.py

+25
Original file line numberDiff line numberDiff line change
@@ -2405,3 +2405,28 @@ def test_index(self):
24052405
dtype="timedelta64[ns]",
24062406
)
24072407
tm.assert_series_equal(algos.mode(idx), exp)
2408+
2409+
2410+
class TestDiff:
2411+
@pytest.mark.parametrize("dtype", ["M8[ns]", "m8[ns]"])
2412+
def test_diff_datetimelike_nat(self, dtype):
2413+
# NaT - NaT is NaT, not 0
2414+
arr = np.arange(12).astype(np.int64).view(dtype).reshape(3, 4)
2415+
arr[:, 2] = arr.dtype.type("NaT", "ns")
2416+
result = algos.diff(arr, 1, axis=0)
2417+
2418+
expected = np.ones(arr.shape, dtype="timedelta64[ns]") * 4
2419+
expected[:, 2] = np.timedelta64("NaT", "ns")
2420+
expected[0, :] = np.timedelta64("NaT", "ns")
2421+
2422+
tm.assert_numpy_array_equal(result, expected)
2423+
2424+
result = algos.diff(arr.T, 1, axis=1)
2425+
tm.assert_numpy_array_equal(result, expected.T)
2426+
2427+
def test_diff_ea_axis(self):
2428+
dta = pd.date_range("2016-01-01", periods=3, tz="US/Pacific")._data
2429+
2430+
msg = "cannot diff DatetimeArray on axis=1"
2431+
with pytest.raises(ValueError, match=msg):
2432+
algos.diff(dta, 1, axis=1)

0 commit comments

Comments
 (0)