Skip to content

Commit eba1f35

Browse files
asmeurerhonno
authored andcommitted
Update test_cross to use _test_stacks
1 parent d97fc9f commit eba1f35

File tree

1 file changed

+13
-22
lines changed

1 file changed

+13
-22
lines changed

array_api_tests/test_linalg.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pytest
1717
from hypothesis import assume, given
1818
from hypothesis.strategies import (booleans, composite, none, tuples, integers,
19-
shared, sampled_from, data, just)
19+
shared, sampled_from, one_of, data, just)
2020
from ndindex import iter_indices
2121

2222
from .array_helpers import assert_exactly_equal, asarray
@@ -29,7 +29,6 @@
2929
SQRT_MAX_ARRAY_SIZE, finite_matrices)
3030
from . import dtype_helpers as dh
3131
from . import pytest_helpers as ph
32-
from . import shape_helpers as sh
3332

3433
from . import _array_module
3534
from . import _array_module as xp
@@ -162,26 +161,18 @@ def test_cross(x1_x2_kw):
162161
assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "cross() did not return the correct dtype"
163162
assert res.shape == shape, "cross() did not return the correct shape"
164163

165-
# cross is too different from other functions to use _test_stacks, and it
166-
# is the only function that works the way it does, so it's not really
167-
# worth generalizing _test_stacks to handle it.
168-
a = axis if axis >= 0 else axis + len(shape)
169-
for _idx in sh.ndindex(shape[:a] + shape[a+1:]):
170-
idx = _idx[:a] + (slice(None),) + _idx[a:]
171-
assert len(idx) == len(shape), "Invalid index. This indicates a bug in the test suite."
172-
res_stack = res[idx]
173-
x1_stack = x1[idx]
174-
x2_stack = x2[idx]
175-
assert x1_stack.shape == x2_stack.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
176-
decomp_res_stack = linalg.cross(x1_stack, x2_stack)
177-
assert_exactly_equal(res_stack, decomp_res_stack)
178-
179-
exact_cross = asarray([
180-
x1_stack[1]*x2_stack[2] - x1_stack[2]*x2_stack[1],
181-
x1_stack[2]*x2_stack[0] - x1_stack[0]*x2_stack[2],
182-
x1_stack[0]*x2_stack[1] - x1_stack[1]*x2_stack[0],
183-
], dtype=res.dtype)
184-
assert_exactly_equal(res_stack, exact_cross)
164+
def exact_cross(a, b):
165+
assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
166+
return asarray([
167+
a[1]*b[2] - a[2]*b[1],
168+
a[2]*b[0] - a[0]*b[2],
169+
a[0]*b[1] - a[1]*b[0],
170+
], dtype=res.dtype)
171+
172+
# We don't want to pass in **kw here because that would pass axis to
173+
# cross() on a single stack, but the axis is not meaningful on unstacked
174+
# vectors.
175+
_test_stacks(linalg.cross, x1, x2, dims=1, matrix_axes=(axis,), res=res, true_val=exact_cross)
185176

186177
@pytest.mark.xp_extension('linalg')
187178
@given(

0 commit comments

Comments
 (0)