@@ -354,14 +354,14 @@ def test_eye(n_rows, n_cols, kw):
354
354
ph .assert_kw_dtype ("eye" , kw_dtype = kw ["dtype" ], out_dtype = out .dtype )
355
355
_n_cols = n_rows if n_cols is None else n_cols
356
356
ph .assert_shape ("eye" , out_shape = out .shape , expected = (n_rows , _n_cols ), kw = dict (n_rows = n_rows , n_cols = n_cols ))
357
- f_func = f"[eye( { n_rows = } , { n_cols = } )]"
358
- for i in range ( n_rows ):
359
- for j in range (_n_cols ):
360
- f_indexed_out = f"out[ { i } , { j } ]= { out [ i , j ] } "
361
- if j - i == kw . get ( "k" , 0 ):
362
- assert out [ i , j ] == 1 , f" { f_indexed_out } , should be 1 { f_func } "
363
- else :
364
- assert out [ i , j ] == 0 , f" { f_indexed_out } , should be 0 { f_func } "
357
+ k = kw . get ( "k" , 0 )
358
+ expected = xp . asarray (
359
+ [[ 1 if j - i == k else 0 for j in range (_n_cols )] for i in range ( n_rows )],
360
+ dtype = out . dtype # Note: dtype already checked above.
361
+ )
362
+ if expected . size == 0 :
363
+ expected = xp . reshape ( expected , ( n_rows , _n_cols ))
364
+ ph . assert_array_elements ( "eye" , out = out , expected = expected , kw = kw )
365
365
366
366
367
367
default_unsafe_dtypes = [xp .uint64 ]
0 commit comments