diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 35ff1d42..7e28d2b9 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -768,6 +768,7 @@ def _test_tensordot_stacks(x1, x2, kw, res): indices = [range(len(s))[i] for i in a] repl = dict(zip(sorted(indices), range(len(indices)))) res_axes.append(tuple(repl[i] for i in indices)) + res_axes = tuple(res_axes) for ((i,), (j,)), (res_idx,) in zip( itertools.product(