|
16 | 16 | import pytest
|
17 | 17 | from hypothesis import assume, given
|
18 | 18 | from hypothesis.strategies import (booleans, composite, none, tuples, integers,
|
19 |
| - shared, sampled_from, data, just) |
| 19 | + shared, sampled_from, one_of, data, just) |
20 | 20 | from ndindex import iter_indices
|
21 | 21 |
|
22 | 22 | from .array_helpers import assert_exactly_equal, asarray
|
|
29 | 29 | SQRT_MAX_ARRAY_SIZE, finite_matrices)
|
30 | 30 | from . import dtype_helpers as dh
|
31 | 31 | from . import pytest_helpers as ph
|
32 |
| -from . import shape_helpers as sh |
33 | 32 |
|
34 | 33 | from . import _array_module
|
35 | 34 | from . import _array_module as xp
|
@@ -162,26 +161,18 @@ def test_cross(x1_x2_kw):
|
162 | 161 | assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "cross() did not return the correct dtype"
|
163 | 162 | assert res.shape == shape, "cross() did not return the correct shape"
|
164 | 163 |
|
165 |
| - # cross is too different from other functions to use _test_stacks, and it |
166 |
| - # is the only function that works the way it does, so it's not really |
167 |
| - # worth generalizing _test_stacks to handle it. |
168 |
| - a = axis if axis >= 0 else axis + len(shape) |
169 |
| - for _idx in sh.ndindex(shape[:a] + shape[a+1:]): |
170 |
| - idx = _idx[:a] + (slice(None),) + _idx[a:] |
171 |
| - assert len(idx) == len(shape), "Invalid index. This indicates a bug in the test suite." |
172 |
| - res_stack = res[idx] |
173 |
| - x1_stack = x1[idx] |
174 |
| - x2_stack = x2[idx] |
175 |
| - assert x1_stack.shape == x2_stack.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite." |
176 |
| - decomp_res_stack = linalg.cross(x1_stack, x2_stack) |
177 |
| - assert_exactly_equal(res_stack, decomp_res_stack) |
178 |
| - |
179 |
| - exact_cross = asarray([ |
180 |
| - x1_stack[1]*x2_stack[2] - x1_stack[2]*x2_stack[1], |
181 |
| - x1_stack[2]*x2_stack[0] - x1_stack[0]*x2_stack[2], |
182 |
| - x1_stack[0]*x2_stack[1] - x1_stack[1]*x2_stack[0], |
183 |
| - ], dtype=res.dtype) |
184 |
| - assert_exactly_equal(res_stack, exact_cross) |
| 164 | + def exact_cross(a, b): |
| 165 | + assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite." |
| 166 | + return asarray([ |
| 167 | + a[1]*b[2] - a[2]*b[1], |
| 168 | + a[2]*b[0] - a[0]*b[2], |
| 169 | + a[0]*b[1] - a[1]*b[0], |
| 170 | + ], dtype=res.dtype) |
| 171 | + |
| 172 | + # We don't want to pass in **kw here because that would pass axis to |
| 173 | + # cross() on a single stack, but the axis is not meaningful on unstacked |
| 174 | + # vectors. |
| 175 | + _test_stacks(linalg.cross, x1, x2, dims=1, matrix_axes=(axis,), res=res, true_val=exact_cross) |
185 | 176 |
|
186 | 177 | @pytest.mark.xp_extension('linalg')
|
187 | 178 | @given(
|
|
0 commit comments