Skip to content

Commit 1a73804

Browse files
authored
Merge pull request #181 from asmeurer/torch-fixes
Fix some issues
2 parents e0bb425 + 8a3d1bf commit 1a73804

File tree

3 files changed

+15
-13
lines changed

3 files changed

+15
-13
lines changed

array_api_tests/test_array_object.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param:
250250
@pytest.mark.parametrize(
251251
"method_name, dtype, stype",
252252
[make_param("__bool__", xp.bool, bool)]
253-
+ [make_param("__int__", d, int) for d in dh.all_int_dtypes]
254-
+ [make_param("__index__", d, int) for d in dh.all_int_dtypes]
253+
+ [make_param("__int__", d, int) for d in dh._filter_stubs(*dh.all_int_dtypes)]
254+
+ [make_param("__index__", d, int) for d in dh._filter_stubs(*dh.all_int_dtypes)]
255255
+ [make_param("__float__", d, float) for d in dh.float_dtypes],
256256
)
257257
@given(data=st.data())

array_api_tests/test_data_type_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def test_finfo(dtype):
164164
# TODO: test values
165165

166166

167-
@pytest.mark.parametrize("dtype", dh.all_int_dtypes, ids=make_dtype_id)
167+
@pytest.mark.parametrize("dtype", dh._filter_stubs(*dh.all_int_dtypes), ids=make_dtype_id)
168168
def test_iinfo(dtype):
169169
out = xp.iinfo(dtype)
170170
f_func = f"[iinfo({dh.dtype_to_name[dtype]})]"

array_api_tests/test_indexing_functions.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,16 @@ def test_take(x, data):
3232

3333
out = xp.take(x, indices, axis=axis)
3434

35-
ph.assert_dtype("take", x.dtype, out.dtype)
35+
ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
3636
ph.assert_shape(
3737
"take",
38-
out.shape,
39-
x.shape[:axis] + (len(_indices),) + x.shape[axis + 1 :],
40-
x=x,
41-
indices=indices,
42-
axis=axis,
38+
out_shape=out.shape,
39+
expected=x.shape[:axis] + (len(_indices),) + x.shape[axis + 1 :],
40+
kw=dict(
41+
x=x,
42+
indices=indices,
43+
axis=axis,
44+
),
4345
)
4446
out_indices = sh.ndindex(out.shape)
4547
axis_indices = list(sh.axis_ndindex(x.shape, axis))
@@ -52,10 +54,10 @@ def test_take(x, data):
5254
out_idx = next(out_indices)
5355
ph.assert_0d_equals(
5456
"take",
55-
sh.fmt_idx(f_take_idx, at_idx),
56-
indexed_x[at_idx],
57-
sh.fmt_idx("out", out_idx),
58-
out[out_idx],
57+
x_repr=sh.fmt_idx(f_take_idx, at_idx),
58+
x_val=indexed_x[at_idx],
59+
out_repr=sh.fmt_idx("out", out_idx),
60+
out_val=out[out_idx],
5961
)
6062
# sanity check
6163
with pytest.raises(StopIteration):

0 commit comments

Comments
 (0)