|
24 | 24 | square_matrix_shapes, symmetric_matrices,
|
25 | 25 | positive_definite_matrices, MAX_ARRAY_SIZE,
|
26 | 26 | invertible_matrices, two_mutual_arrays,
|
27 |
| - mutually_promotable_dtypes, one_d_shapes) |
| 27 | + mutually_promotable_dtypes, one_d_shapes, |
| 28 | + two_mutually_broadcastable_shapes, SQRT_MAX_ARRAY_SIZE) |
28 | 29 | from .pytest_helpers import raises
|
29 | 30 |
|
30 | 31 | from .test_broadcasting import broadcast_shapes
|
@@ -440,11 +441,38 @@ def test_slogdet(x):
|
440 | 441 | # TODO: Test this when we have tests for floating-point values.
|
441 | 442 | # assert all(abs(linalg.det(x) - sign*exp(logabsdet)) < eps)
|
442 | 443 |
|
443 |
| -@given( |
444 |
| - x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes), |
445 |
| - x2=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes), |
446 |
| -) |
| 444 | +def solve_args(): |
| 445 | + """ |
| 446 | + Strategy for the x1 and x2 arguments to test_solve() |
| 447 | +
|
| 448 | + solve() takes x1, x2, where x1 is any stack of square invertible matrices |
| 449 | + of shape (..., M, M), and x2 is either shape (..., M) or (..., M, K), |
| 450 | + where the ... parts of x1 and x2 are broadcast compatible. |
| 451 | + """ |
| 452 | + stack_shapes = shared(two_mutually_broadcastable_shapes) |
| 453 | + # Don't worry about dtypes since all floating dtypes are type promotable |
| 454 | + # with each other. |
| 455 | + x1 = shared(invertible_matrices(stack_shapes=stack_shapes.map(lambda pair: |
| 456 | + pair[0]))) |
| 457 | + |
| 458 | + @composite |
| 459 | + def x2_shapes(draw): |
| 460 | + end = draw(xps.array_shapes(min_dims=0, max_dims=1, min_side=0, |
| 461 | + max_side=SQRT_MAX_ARRAY_SIZE)) |
| 462 | + return draw(stack_shapes)[1] + draw(x1).shape[-1:] + end |
| 463 | + |
| 464 | + x2 = xps.arrays(dtype=xps.floating_dtypes(), shape=x2_shapes()) |
| 465 | + return x1, x2 |
| 466 | + |
| 467 | +@given(*solve_args()) |
447 | 468 | def test_solve(x1, x2):
|
| 469 | + # TODO: solve() is currently ambiguous, in that some inputs can be |
| 470 | + # interpreted in two different ways. For example, if x1 is shape (2, 2, 2) |
| 471 | + # and x2 is shape (2, 2), should this be interpreted as x2 is (2,) stack |
| 472 | + # of a (2,) vector, i.e., the result would be (2, 2, 2, 1) after |
| 473 | + # broadcasting, or as a single stack of a 2x2 matrix, i.e., resulting in |
| 474 | + # (2, 2, 2, 2). |
| 475 | + |
448 | 476 | # res = linalg.solve(x1, x2)
|
449 | 477 | pass
|
450 | 478 |
|
|
0 commit comments