From 54840418ebb99b14f026b68122c15449e189b4ce Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 27 Feb 2024 15:39:38 -0700 Subject: [PATCH] Avoid passing a list of axes to tensordot cupy has logic that fails if axis is not a tuple. The standard type annotations only require axes to be a tuple. --- array_api_tests/test_linalg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 35ff1d42..7e28d2b9 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -768,6 +768,7 @@ def _test_tensordot_stacks(x1, x2, kw, res): indices = [range(len(s))[i] for i in a] repl = dict(zip(sorted(indices), range(len(indices)))) res_axes.append(tuple(repl[i] for i in indices)) + res_axes = tuple(res_axes) for ((i,), (j,)), (res_idx,) in zip( itertools.product(