@@ -126,8 +126,9 @@ def _test_namedtuple(res, fields, func_name):
126
126
def test_cholesky (x , kw ):
127
127
res = linalg .cholesky (x , ** kw )
128
128
129
- assert res .shape == x .shape , "cholesky() did not return the correct shape"
130
- assert res .dtype == x .dtype , "cholesky() did not return the correct dtype"
129
+ ph .assert_dtype ("cholesky" , in_dtype = x .dtype , out_dtype = res .dtype )
130
+ ph .assert_result_shape ("cholesky" , in_shapes = [x .shape ],
131
+ out_shape = res .shape , expected = x .shape )
131
132
132
133
_test_stacks (linalg .cholesky , x , ** kw , res = res )
133
134
@@ -192,7 +193,7 @@ def test_cross(x1_x2_kw):
192
193
193
194
ph .assert_dtype ("cross" , in_dtype = [x1 .dtype , x2 .dtype ],
194
195
out_dtype = res .dtype )
195
- ph .assert_shape ("cross" , out_shape = res .shape , expected = broadcasted_shape )
196
+ ph .assert_result_shape ("cross" , in_shapes = [ x1 . shape , x2 . shape ] , out_shape = res .shape , expected = broadcasted_shape )
196
197
197
198
def exact_cross (a , b ):
198
199
assert a .shape == b .shape == (3 ,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
@@ -214,8 +215,9 @@ def exact_cross(a, b):
214
215
def test_det (x ):
215
216
res = linalg .det (x )
216
217
217
- assert res .dtype == x .dtype , "det() did not return the correct dtype"
218
- assert res .shape == x .shape [:- 2 ], "det() did not return the correct shape"
218
+ ph .assert_dtype ("det" , in_dtype = x .dtype , out_dtype = res .dtype )
219
+ ph .assert_result_shape ("det" , in_shapes = [x .shape ], out_shape = res .shape ,
220
+ expected = x .shape [:- 2 ])
219
221
220
222
_test_stacks (linalg .det , x , res = res , dims = 0 )
221
223
@@ -231,7 +233,7 @@ def test_det(x):
231
233
def test_diagonal (x , kw ):
232
234
res = linalg .diagonal (x , ** kw )
233
235
234
- assert res . dtype == x .dtype , "diagonal() returned the wrong dtype"
236
+ ph . assert_dtype ( "diagonal" , in_dtype = x .dtype , out_dtype = res . dtype )
235
237
236
238
n , m = x .shape [- 2 :]
237
239
offset = kw .get ('offset' , 0 )
@@ -245,7 +247,9 @@ def test_diagonal(x, kw):
245
247
else :
246
248
diag_size = min (n , m , max (m - offset , 0 ))
247
249
248
- assert res .shape == (* x .shape [:- 2 ], diag_size ), "diagonal() returned the wrong shape"
250
+ expected_shape = (* x .shape [:- 2 ], diag_size )
251
+ ph .assert_result_shape ("diagonal" , in_shapes = [x .shape ],
252
+ out_shape = res .shape , expected = expected_shape )
249
253
250
254
def true_diag (x_stack , offset = 0 ):
251
255
if offset >= 0 :
@@ -266,11 +270,18 @@ def test_eigh(x):
266
270
eigenvalues = res .eigenvalues
267
271
eigenvectors = res .eigenvectors
268
272
269
- assert eigenvalues .dtype == x .dtype , "eigh().eigenvalues did not return the correct dtype"
270
- assert eigenvalues .shape == x .shape [:- 1 ], "eigh().eigenvalues did not return the correct shape"
273
+ ph .assert_dtype ("eigh" , in_dtype = x .dtype , out_dtype = eigenvalues .dtype ,
274
+ expected = x .dtype , repr_name = "eigenvalues.dtype" )
275
+ ph .assert_result_shape ("eigh" , in_shapes = [x .shape ],
276
+ out_shape = eigenvalues .shape ,
277
+ expected = x .shape [:- 1 ],
278
+ repr_name = "eigenvalues.shape" )
271
279
272
- assert eigenvectors .dtype == x .dtype , "eigh().eigenvectors did not return the correct dtype"
273
- assert eigenvectors .shape == x .shape , "eigh().eigenvectors did not return the correct shape"
280
+ ph .assert_dtype ("eigh" , in_dtype = x .dtype , out_dtype = eigenvectors .dtype ,
281
+ expected = x .dtype , repr_name = "eigenvectors.dtype" )
282
+ ph .assert_result_shape ("eigh" , in_shapes = [x .shape ],
283
+ out_shape = eigenvectors .shape , expected = x .shape ,
284
+ repr_name = "eigenvectors.shape" )
274
285
275
286
# Note: _test_stacks here is only testing the shape and dtype. The actual
276
287
# eigenvalues and eigenvectors may not be equal at all, since there is not
@@ -292,8 +303,9 @@ def test_eigh(x):
292
303
def test_eigvalsh (x ):
293
304
res = linalg .eigvalsh (x )
294
305
295
- assert res .dtype == x .dtype , "eigvalsh() did not return the correct dtype"
296
- assert res .shape == x .shape [:- 1 ], "eigvalsh() did not return the correct shape"
306
+ ph .assert_dtype ("eigvalsh" , in_dtype = x .dtype , out_dtype = res .dtype )
307
+ ph .assert_result_shape ("eigvalsh" , in_shapes = [x .shape ],
308
+ out_shape = res .shape , expected = x .shape [:- 1 ])
297
309
298
310
# Note: _test_stacks here is only testing the shape and dtype. The actual
299
311
# eigenvalues may not be equal at all, since there is not requirements or
@@ -311,8 +323,9 @@ def test_eigvalsh(x):
311
323
def test_inv (x ):
312
324
res = linalg .inv (x )
313
325
314
- assert res .shape == x .shape , "inv() did not return the correct shape"
315
- assert res .dtype == x .dtype , "inv() did not return the correct dtype"
326
+ ph .assert_dtype ("inv" , in_dtype = x .dtype , out_dtype = res .dtype )
327
+ ph .assert_result_shape ("inv" , in_shapes = [x .shape ], out_shape = res .shape ,
328
+ expected = x .shape )
316
329
317
330
_test_stacks (linalg .inv , x , res = res )
318
331
@@ -339,18 +352,24 @@ def test_matmul(x1, x2):
339
352
ph .assert_dtype ("matmul" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = res .dtype )
340
353
341
354
if len (x1 .shape ) == len (x2 .shape ) == 1 :
342
- assert res .shape == ()
355
+ ph .assert_result_shape ("matmul" , in_shapes = [x1 .shape , x2 .shape ],
356
+ out_shape = res .shape , expected = ())
343
357
elif len (x1 .shape ) == 1 :
344
- assert res .shape == x2 .shape [:- 2 ] + x2 .shape [- 1 :]
358
+ ph .assert_result_shape ("matmul" , in_shapes = [x1 .shape , x2 .shape ],
359
+ out_shape = res .shape ,
360
+ expected = x2 .shape [:- 2 ] + x2 .shape [- 1 :])
345
361
_test_stacks (_array_module .matmul , x1 , x2 , res = res , dims = 1 ,
346
362
matrix_axes = [(0 ,), (- 2 , - 1 )], res_axes = [- 1 ])
347
363
elif len (x2 .shape ) == 1 :
348
- assert res .shape == x1 .shape [:- 1 ]
364
+ ph .assert_result_shape ("matmul" , in_shapes = [x1 .shape , x2 .shape ],
365
+ out_shape = res .shape , expected = x1 .shape [:- 1 ])
349
366
_test_stacks (_array_module .matmul , x1 , x2 , res = res , dims = 1 ,
350
367
matrix_axes = [(- 2 , - 1 ), (0 ,)], res_axes = [- 1 ])
351
368
else :
352
369
stack_shape = sh .broadcast_shapes (x1 .shape [:- 2 ], x2 .shape [:- 2 ])
353
- assert res .shape == stack_shape + (x1 .shape [- 2 ], x2 .shape [- 1 ])
370
+ ph .assert_result_shape ("matmul" , in_shapes = [x1 .shape , x2 .shape ],
371
+ out_shape = res .shape ,
372
+ expected = stack_shape + (x1 .shape [- 2 ], x2 .shape [- 1 ]))
354
373
_test_stacks (_array_module .matmul , x1 , x2 , res = res )
355
374
356
375
@pytest .mark .xp_extension ('linalg' )
@@ -370,8 +389,9 @@ def test_matrix_norm(x, kw):
370
389
expected_shape = x .shape [:- 2 ] + (1 , 1 )
371
390
else :
372
391
expected_shape = x .shape [:- 2 ]
373
- assert res .shape == expected_shape , f"matrix_norm({ keepdims = } ) did not return the correct shape"
374
- assert res .dtype == x .dtype , "matrix_norm() did not return the correct dtype"
392
+ ph .assert_dtype ("matrix_norm" , in_dtype = x .dtype , out_dtype = res .dtype )
393
+ ph .assert_result_shape ("matrix_norm" , in_shapes = [x .shape ],
394
+ out_shape = res .shape , expected = expected_shape )
375
395
376
396
_test_stacks (linalg .matrix_norm , x , ** kw , dims = 2 if keepdims else 0 ,
377
397
res = res )
@@ -388,8 +408,9 @@ def test_matrix_norm(x, kw):
388
408
def test_matrix_power (x , n ):
389
409
res = linalg .matrix_power (x , n )
390
410
391
- assert res .shape == x .shape , "matrix_power() did not return the correct shape"
392
- assert res .dtype == x .dtype , "matrix_power() did not return the correct dtype"
411
+ ph .assert_dtype ("matrix_power" , in_dtype = x .dtype , out_dtype = res .dtype )
412
+ ph .assert_result_shape ("matrix_power" , in_shapes = [x .shape ],
413
+ out_shape = res .shape , expected = x .shape )
393
414
394
415
if n == 0 :
395
416
true_val = lambda x : _array_module .eye (x .shape [0 ], dtype = x .dtype )
@@ -419,8 +440,9 @@ def test_matrix_transpose(x):
419
440
shape = list (x .shape )
420
441
shape [- 1 ], shape [- 2 ] = shape [- 2 ], shape [- 1 ]
421
442
shape = tuple (shape )
422
- assert res .shape == shape , "matrix_transpose() did not return the correct shape"
423
- assert res .dtype == x .dtype , "matrix_transpose() did not return the correct dtype"
443
+ ph .assert_dtype ("matrix_transpose" , in_dtype = x .dtype , out_dtype = res .dtype )
444
+ ph .assert_result_shape ("matrix_transpose" , in_shapes = [x .shape ],
445
+ out_shape = res .shape , expected = shape )
424
446
425
447
_test_stacks (_array_module .matrix_transpose , x , res = res , true_val = true_val )
426
448
@@ -435,8 +457,9 @@ def test_outer(x1, x2):
435
457
res = linalg .outer (x1 , x2 )
436
458
437
459
shape = (x1 .shape [0 ], x2 .shape [0 ])
438
- assert res .shape == shape , "outer() did not return the correct shape"
439
- assert res .dtype == dh .result_type (x1 .dtype , x2 .dtype ), "outer() did not return the correct dtype"
460
+ ph .assert_dtype ("outer" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = res .dtype )
461
+ ph .assert_result_shape ("outer" , in_shapes = [x1 .shape , x2 .shape ],
462
+ out_shape = res .shape , expected = shape )
440
463
441
464
if 0 in shape :
442
465
true_res = _array_module .empty (shape , dtype = res .dtype )
@@ -472,17 +495,23 @@ def test_qr(x, kw):
472
495
Q = res .Q
473
496
R = res .R
474
497
475
- assert Q .dtype == x .dtype , "qr().Q did not return the correct dtype"
498
+ ph .assert_dtype ("qr" , in_dtype = x .dtype , out_dtype = Q .dtype ,
499
+ expected = x .dtype , repr_name = "Q.dtype" )
476
500
if mode == 'complete' :
477
- assert Q . shape == x .shape [:- 2 ] + (M , M ), "qr().Q did not return the correct shape"
501
+ expected_Q_shape = x .shape [:- 2 ] + (M , M )
478
502
else :
479
- assert Q .shape == x .shape [:- 2 ] + (M , K ), "qr().Q did not return the correct shape"
503
+ expected_Q_shape = x .shape [:- 2 ] + (M , K )
504
+ ph .assert_result_shape ("qr" , in_shapes = [x .shape ], out_shape = Q .shape ,
505
+ expected = expected_Q_shape , repr_name = "Q.shape" )
480
506
481
- assert R .dtype == x .dtype , "qr().R did not return the correct dtype"
507
+ ph .assert_dtype ("qr" , in_dtype = x .dtype , out_dtype = R .dtype ,
508
+ expected = x .dtype , repr_name = "R.dtype" )
482
509
if mode == 'complete' :
483
- assert R . shape == x .shape [:- 2 ] + (M , N ), "qr().R did not return the correct shape"
510
+ expected_R_shape = x .shape [:- 2 ] + (M , N )
484
511
else :
485
- assert R .shape == x .shape [:- 2 ] + (K , N ), "qr().R did not return the correct shape"
512
+ expected_R_shape = x .shape [:- 2 ] + (K , N )
513
+ ph .assert_result_shape ("qr" , in_shapes = [x .shape ], out_shape = R .shape ,
514
+ expected = expected_R_shape , repr_name = "R.shape" )
486
515
487
516
_test_stacks (lambda x : linalg .qr (x , ** kw ).Q , x , res = Q )
488
517
_test_stacks (lambda x : linalg .qr (x , ** kw ).R , x , res = R )
@@ -505,14 +534,17 @@ def test_slogdet(x):
505
534
506
535
ph .assert_dtype ("slogdet" , in_dtype = x .dtype , out_dtype = sign .dtype ,
507
536
expected = x .dtype , repr_name = "sign.dtype" )
508
- ph .assert_shape ("slogdet" , out_shape = sign .shape , expected = x .shape [:- 2 ],
509
- repr_name = "sign.shape" )
537
+ ph .assert_result_shape ("slogdet" , in_shapes = [x .shape ],
538
+ out_shape = sign .shape ,
539
+ expected = x .shape [:- 2 ],
540
+ repr_name = "sign.shape" )
510
541
expected_dtype = dh .as_real_dtype (x .dtype )
511
542
ph .assert_dtype ("slogdet" , in_dtype = x .dtype , out_dtype = logabsdet .dtype ,
512
543
expected = expected_dtype , repr_name = "logabsdet.dtype" )
513
- ph .assert_shape ("slogdet" , out_shape = logabsdet .shape ,
514
- expected = x .shape [:- 2 ],
515
- repr_name = "logabsdet.shape" )
544
+ ph .assert_result_shape ("slogdet" , in_shapes = [x .shape ],
545
+ out_shape = logabsdet .shape ,
546
+ expected = x .shape [:- 2 ],
547
+ repr_name = "logabsdet.shape" )
516
548
517
549
_test_stacks (lambda x : linalg .slogdet (x ).sign , x ,
518
550
res = sign , dims = 0 )
@@ -584,17 +616,31 @@ def test_svd(x, kw):
584
616
585
617
U , S , Vh = res
586
618
587
- assert U .dtype == x .dtype , "svd().U did not return the correct dtype"
588
- assert S .dtype == x .dtype , "svd().S did not return the correct dtype"
589
- assert Vh .dtype == x .dtype , "svd().Vh did not return the correct dtype"
619
+ ph .assert_dtype ("svd" , in_dtype = x .dtype , out_dtype = U .dtype ,
620
+ expected = x .dtype , repr_name = "U.dtype" )
621
+ ph .assert_dtype ("svd" , in_dtype = x .dtype , out_dtype = S .dtype ,
622
+ expected = x .dtype , repr_name = "S.dtype" )
623
+ ph .assert_dtype ("svd" , in_dtype = x .dtype , out_dtype = Vh .dtype ,
624
+ expected = x .dtype , repr_name = "Vh.dtype" )
590
625
591
626
if full_matrices :
592
- assert U . shape == (* stack , M , M ), "svd().U did not return the correct shape"
593
- assert Vh . shape == (* stack , N , N ), "svd().Vh did not return the correct shape"
627
+ expected_U_shape = (* stack , M , M )
628
+ expected_Vh_shape = (* stack , N , N )
594
629
else :
595
- assert U .shape == (* stack , M , K ), "svd(full_matrices=False).U did not return the correct shape"
596
- assert Vh .shape == (* stack , K , N ), "svd(full_matrices=False).Vh did not return the correct shape"
597
- assert S .shape == (* stack , K ), "svd().S did not return the correct shape"
630
+ expected_U_shape = (* stack , M , K )
631
+ expected_Vh_shape = (* stack , K , N )
632
+ ph .assert_result_shape ("svd" , in_shapes = [x .shape ],
633
+ out_shape = U .shape ,
634
+ expected = expected_U_shape ,
635
+ repr_name = "U.shape" )
636
+ ph .assert_result_shape ("svd" , in_shapes = [x .shape ],
637
+ out_shape = Vh .shape ,
638
+ expected = expected_Vh_shape ,
639
+ repr_name = "Vh.shape" )
640
+ ph .assert_result_shape ("svd" , in_shapes = [x .shape ],
641
+ out_shape = S .shape ,
642
+ expected = (* stack , K ),
643
+ repr_name = "S.shape" )
598
644
599
645
# The values of s must be sorted from largest to smallest
600
646
if K >= 1 :
@@ -614,8 +660,11 @@ def test_svdvals(x):
614
660
* stack , M , N = x .shape
615
661
K = min (M , N )
616
662
617
- assert res .dtype == x .dtype , "svdvals() did not return the correct dtype"
618
- assert res .shape == (* stack , K ), "svdvals() did not return the correct shape"
663
+ ph .assert_dtype ("svdvals" , in_dtype = x .dtype , out_dtype = res .dtype ,
664
+ expected = x .dtype )
665
+ ph .assert_result_shape ("svdvals" , in_shapes = [x .shape ],
666
+ out_shape = res .shape ,
667
+ expected = (* stack , K ))
619
668
620
669
# SVD values must be sorted from largest to smallest
621
670
assert _array_module .all (res [..., :- 1 ] >= res [..., 1 :]), "svdvals() values are not sorted from largest to smallest"
@@ -753,7 +802,7 @@ def test_trace(x, kw):
753
802
# assert res.dtype == x.dtype, "trace() returned the wrong dtype"
754
803
755
804
n , m = x .shape [- 2 :]
756
- assert res .shape == x .shape [:- 2 ], "trace() returned the wrong shape"
805
+ ph . assert_result_shape ( 'trace' , x . shape , res .shape , expected = x .shape [:- 2 ])
757
806
758
807
def true_trace (x_stack , offset = 0 ):
759
808
# Note: the spec does not specify that offset must be within the
@@ -799,7 +848,8 @@ def test_vecdot(x1, x2, data):
799
848
800
849
ph .assert_dtype ("vecdot" , in_dtype = [x1 .dtype , x2 .dtype ],
801
850
out_dtype = res .dtype )
802
- ph .assert_shape ("vecdot" , out_shape = res .shape , expected = expected_shape )
851
+ ph .assert_result_shape ("vecdot" , in_shapes = [x1 .shape , x2 .shape ],
852
+ out_shape = res .shape , expected = expected_shape )
803
853
804
854
if x1 .dtype in dh .int_dtypes :
805
855
def true_val (x , y , axis = - 1 ):
0 commit comments