diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index ec2df060..7df439f5 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -359,7 +359,7 @@ def test_eye(n_rows, n_cols, kw): [[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)], dtype=out.dtype # Note: dtype already checked above. ) - if expected.size == 0: + if 0 in expected.shape: expected = xp.reshape(expected, (n_rows, _n_cols)) ph.assert_array_elements("eye", out=out, expected=expected, kw=kw)