Skip to content

Commit 8706391

Browse files
committed
Add WIP strategy for test_solve()
linalg.solve is presently ambiguous in its definition. See data-apis/array-api#285. So for now we are not implementing the test, as it is not possible to make a test that works.
1 parent 2038198 commit 8706391

File tree

1 file changed

+33
-5
lines changed

1 file changed

+33
-5
lines changed

array_api_tests/test_linalg.py

+33-5
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
square_matrix_shapes, symmetric_matrices,
2525
positive_definite_matrices, MAX_ARRAY_SIZE,
2626
invertible_matrices, two_mutual_arrays,
27-
mutually_promotable_dtypes, one_d_shapes)
27+
mutually_promotable_dtypes, one_d_shapes,
28+
two_mutually_broadcastable_shapes, SQRT_MAX_ARRAY_SIZE)
2829
from .pytest_helpers import raises
2930

3031
from .test_broadcasting import broadcast_shapes
@@ -440,11 +441,38 @@ def test_slogdet(x):
440441
# TODO: Test this when we have tests for floating-point values.
441442
# assert all(abs(linalg.det(x) - sign*exp(logabsdet)) < eps)
442443

443-
@given(
444-
x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
445-
x2=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
446-
)
444+
def solve_args():
445+
"""
446+
Strategy for the x1 and x2 arguments to test_solve()
447+
448+
solve() takes x1, x2, where x1 is any stack of square invertible matrices
449+
of shape (..., M, M), and x2 is either shape (..., M) or (..., M, K),
450+
where the ... parts of x1 and x2 are broadcast compatible.
451+
"""
452+
stack_shapes = shared(two_mutually_broadcastable_shapes)
453+
# Don't worry about dtypes since all floating dtypes are type promotable
454+
# with each other.
455+
x1 = shared(invertible_matrices(stack_shapes=stack_shapes.map(lambda pair:
456+
pair[0])))
457+
458+
@composite
459+
def x2_shapes(draw):
460+
end = draw(xps.array_shapes(min_dims=0, max_dims=1, min_side=0,
461+
max_side=SQRT_MAX_ARRAY_SIZE))
462+
return draw(stack_shapes)[1] + draw(x1).shape[-1:] + end
463+
464+
x2 = xps.arrays(dtype=xps.floating_dtypes(), shape=x2_shapes())
465+
return x1, x2
466+
467+
@given(*solve_args())
447468
def test_solve(x1, x2):
469+
# TODO: solve() is currently ambiguous, in that some inputs can be
470+
# interpreted in two different ways. For example, if x1 is shape (2, 2, 2)
471+
# and x2 is shape (2, 2), should this be interpreted as x2 is (2,) stack
472+
# of a (2,) vector, i.e., the result would be (2, 2, 2, 1) after
473+
# broadcasting, or as a single stack of a 2x2 matrix, i.e., resulting in
474+
# (2, 2, 2, 2).
475+
448476
# res = linalg.solve(x1, x2)
449477
pass
450478

0 commit comments

Comments
 (0)