Skip to content

Commit 243dfea

Browse files
committed
test_eye: use assert_array_elements utility
1 parent 81d088d commit 243dfea

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

array_api_tests/test_creation_functions.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -354,14 +354,14 @@ def test_eye(n_rows, n_cols, kw):
354354
ph.assert_kw_dtype("eye", kw_dtype=kw["dtype"], out_dtype=out.dtype)
355355
_n_cols = n_rows if n_cols is None else n_cols
356356
ph.assert_shape("eye", out_shape=out.shape, expected=(n_rows, _n_cols), kw=dict(n_rows=n_rows, n_cols=n_cols))
357-
f_func = f"[eye({n_rows=}, {n_cols=})]"
358-
for i in range(n_rows):
359-
for j in range(_n_cols):
360-
f_indexed_out = f"out[{i}, {j}]={out[i, j]}"
361-
if j - i == kw.get("k", 0):
362-
assert out[i, j] == 1, f"{f_indexed_out}, should be 1 {f_func}"
363-
else:
364-
assert out[i, j] == 0, f"{f_indexed_out}, should be 0 {f_func}"
357+
k = kw.get("k", 0)
358+
expected = xp.asarray(
359+
[[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)],
360+
dtype=out.dtype # Note: dtype already checked above.
361+
)
362+
if expected.size == 0:
363+
expected = xp.reshape(expected, (n_rows, _n_cols))
364+
ph.assert_array_elements("eye", out=out, expected=expected, kw=kw)
365365

366366

367367
default_unsafe_dtypes = [xp.uint64]

0 commit comments

Comments
 (0)