diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 3b17417d..dc2b69d8 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -5,6 +5,8 @@ from typing import Literal, Optional, Tuple, Union from ._typing import ndarray +import math + import numpy as np if np.__version__[0] == "2": from numpy.lib.array_utils import normalize_axis_tuple @@ -110,7 +112,7 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] # on a single dimension. if axis is None: # Note: xp.linalg.norm() doesn't handle 0-D arrays - x = x.ravel() + _x = x.ravel() _axis = 0 elif isinstance(axis, tuple): # Note: The axis argument supports any number of axes, whereas @@ -118,13 +120,14 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] normalized_axis = normalize_axis_tuple(axis, x.ndim) rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) newshape = axis + rest - x = xp.transpose(x, newshape).reshape( - (xp.prod([x.shape[i] for i in axis], dtype=int), *[x.shape[i] for i in rest])) + _x = xp.transpose(x, newshape).reshape( + (math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest])) _axis = 0 else: + _x = x _axis = axis - res = xp.linalg.norm(x, axis=_axis, ord=ord) + res = xp.linalg.norm(_x, axis=_axis, ord=ord) if keepdims: # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks diff --git a/test_cupy.sh b/test_cupy.sh index 1f832d2e..6b4d6b56 100755 --- a/test_cupy.sh +++ b/test_cupy.sh @@ -26,4 +26,4 @@ mkdir -p $SCRIPT_DIR/.hypothesis ln -s $SCRIPT_DIR/.hypothesis .hypothesis export ARRAY_API_TESTS_MODULE=array_api_compat.cupy -pytest ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt --skips-file $SCRIPT_DIR/cupy-skips.txt "$@" +pytest ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@"