Skip to content

Commit e9f516e

Browse files
committed
ENH: Add a basic diff test
1 parent 31eec9d commit e9f516e

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

array_api_tests/test_utility_functions.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,40 @@ def test_any(x, data):
6363
expected = any(elements)
6464
ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx,
6565
out=result, expected=expected, kw=kw)
66+
67+
68+
@pytest.mark.unvectorized
69+
@pytest.mark.min_version("2024.12")
70+
@given(
71+
x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)),
72+
data=st.data(),
73+
)
74+
def test_diff(x, data):
75+
# TODO:
76+
# 1. append/prepend
77+
axis = data.draw(
78+
st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
79+
label="axis"
80+
)
81+
if axis is None:
82+
axis_kw = {"axis": -1}
83+
n_axis = x.ndim - 1
84+
else:
85+
axis_kw = {"axis": axis}
86+
n_axis = axis + x.ndim if axis < 0 else axis
87+
88+
n = data.draw(st.integers(1, min(x.shape[n_axis], 3)))
89+
90+
out = xp.diff(x, **axis_kw, n=n)
91+
92+
expected_shape = list(x.shape)
93+
expected_shape[n_axis] -= n
94+
assert out.shape == tuple(expected_shape)
95+
96+
# value test
97+
if n == 1:
98+
for idx in sh.ndindex(out.shape):
99+
l = list(idx)
100+
l[n_axis] += 1
101+
assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }"
102+

0 commit comments

Comments
 (0)