diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 35ff1d42..f36d31a6 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -597,11 +597,14 @@ def solve_args(): of shape (..., M, M), and x2 is either shape (M,) or (..., M, K), where the ... parts of x1 and x2 are broadcast compatible. """ + mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dh.all_float_dtypes)) + stack_shapes = shared(two_mutually_broadcastable_shapes) # Don't worry about dtypes since all floating dtypes are type promotable # with each other. - x1 = shared(invertible_matrices(stack_shapes=stack_shapes.map(lambda pair: - pair[0]))) + x1 = shared(invertible_matrices( + stack_shapes=stack_shapes.map(lambda pair: pair[0]), + dtypes=mutual_dtypes.map(lambda pair: pair[0]))) @composite def _x2_shapes(draw): @@ -609,7 +612,7 @@ def _x2_shapes(draw): return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,) x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes()) - x2 = arrays(dtype=all_floating_dtypes(), shape=x2_shapes) + x2 = arrays(shape=x2_shapes, dtype=mutual_dtypes.map(lambda pair: pair[1])) return x1, x2 @pytest.mark.xp_extension('linalg') @@ -617,12 +620,19 @@ def _x2_shapes(draw): def test_solve(x1, x2): res = linalg.solve(x1, x2) + ph.assert_dtype("solve", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype) if x2.ndim == 1: + expected_shape = x1.shape[:-2] + x2.shape[-1:] _test_stacks(linalg.solve, x1, x2, res=res, dims=1, matrix_axes=[(-2, -1), (0,)], res_axes=[-1]) else: + stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2]) + expected_shape = stack_shape + x2.shape[-2:] _test_stacks(linalg.solve, x1, x2, res=res, dims=2) + ph.assert_result_shape("solve", in_shapes=[x1.shape, x2.shape], + out_shape=res.shape, expected=expected_shape) + @pytest.mark.xp_extension('linalg') @given( x=finite_matrices(),