@@ -32,14 +32,16 @@ def test_take(x, data):
32
32
33
33
out = xp .take (x , indices , axis = axis )
34
34
35
- ph .assert_dtype ("take" , x .dtype , out .dtype )
35
+ ph .assert_dtype ("take" , in_dtype = x .dtype , out_dtype = out .dtype )
36
36
ph .assert_shape (
37
37
"take" ,
38
- out .shape ,
39
- x .shape [:axis ] + (len (_indices ),) + x .shape [axis + 1 :],
40
- x = x ,
41
- indices = indices ,
42
- axis = axis ,
38
+ out_shape = out .shape ,
39
+ expected = x .shape [:axis ] + (len (_indices ),) + x .shape [axis + 1 :],
40
+ kw = dict (
41
+ x = x ,
42
+ indices = indices ,
43
+ axis = axis ,
44
+ ),
43
45
)
44
46
out_indices = sh .ndindex (out .shape )
45
47
axis_indices = list (sh .axis_ndindex (x .shape , axis ))
@@ -52,10 +54,10 @@ def test_take(x, data):
52
54
out_idx = next (out_indices )
53
55
ph .assert_0d_equals (
54
56
"take" ,
55
- sh .fmt_idx (f_take_idx , at_idx ),
56
- indexed_x [at_idx ],
57
- sh .fmt_idx ("out" , out_idx ),
58
- out [out_idx ],
57
+ x_repr = sh .fmt_idx (f_take_idx , at_idx ),
58
+ x_val = indexed_x [at_idx ],
59
+ out_repr = sh .fmt_idx ("out" , out_idx ),
60
+ out_val = out [out_idx ],
59
61
)
60
62
# sanity check
61
63
with pytest .raises (StopIteration ):
0 commit comments