Skip to content

Commit 0574111

Browse files
committed
Fix test_cross and test_outer
1 parent 035e3f3 commit 0574111

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

array_api_tests/test_linalg.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def cross_args(draw, dtype_objects=dh.numeric_dtypes):
112112
kw = draw(kwargs(axis=integers(-size, size-1)))
113113
axis = kw.get('axis', -1)
114114
shape[axis] = 3
115+
shape = tuple(shape)
115116

116117
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtype_objects))
117118
arrays1 = xps.arrays(
@@ -139,7 +140,7 @@ def test_cross(x1_x2_kw):
139140

140141
res = linalg.cross(x1, x2, **kw)
141142

142-
assert res.dtype == dh.promotion_table[x1, x2], "cross() did not return the correct dtype"
143+
assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "cross() did not return the correct dtype"
143144
assert res.shape == shape, "cross() did not return the correct shape"
144145

145146
# cross is too different from other functions to use _test_stacks, and it
@@ -365,7 +366,7 @@ def test_outer(x1, x2):
365366

366367
shape = (x1.shape[0], x2.shape[0])
367368
assert res.shape == shape, "outer() did not return the correct shape"
368-
assert res.dtype == dh.promotion_table[x1, x2], "outer() did not return the correct dtype"
369+
assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "outer() did not return the correct dtype"
369370

370371
if 0 in shape:
371372
true_res = _array_module.empty(shape, dtype=res.dtype)

0 commit comments

Comments
 (0)