Skip to content

Commit 5368f8e

Browse files
asmeurerhonno
authored andcommitted
Add a basic stacking test for matrix_norm
We only test finite matrices because the svd (ord=2) might raise an exception with infinite values.
1 parent eba1f35 commit 5368f8e

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

array_api_tests/test_linalg.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,14 +307,22 @@ def test_matmul(x1, x2):
307307
assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1])
308308
_test_stacks(_array_module.matmul, x1, x2, res=res)
309309

310+
matrix_norm_shapes = shared(matrix_shapes())
311+
310312
@pytest.mark.xp_extension('linalg')
311313
@given(
312-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()),
313-
kw=kwargs(axis=todo, keepdims=todo, ord=todo)
314+
x=finite_matrices,
315+
kw=kwargs(keepdims=booleans(),
316+
ord=sampled_from([1, 2, float('inf'), -float('inf'), 'fro', 'nuc']))
314317
)
315318
def test_matrix_norm(x, kw):
316-
# res = linalg.matrix_norm(x, **kw)
317-
pass
319+
res = linalg.matrix_norm(x, **kw)
320+
321+
keepdims = kw.get('keepdims', False)
322+
ord = kw.get('ord', 'fro')
323+
324+
_test_stacks(linalg.matrix_norm, x, **kw, dims=2 if keepdims else 0,
325+
res=res)
318326

319327
matrix_power_n = shared(integers(-1000, 1000), key='matrix_power n')
320328
@pytest.mark.xp_extension('linalg')

0 commit comments

Comments
 (0)