From 4b942416cb4ed32c61a3fdf7e9ddd4ed20857348 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 27 Feb 2024 13:37:26 -0700 Subject: [PATCH 1/3] Fix numpy vector_norm(keepdims=True) --- array_api_compat/common/_linalg.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 3b17417d..58fe5491 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -110,7 +110,7 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] # on a single dimension. if axis is None: # Note: xp.linalg.norm() doesn't handle 0-D arrays - x = x.ravel() + _x = x.ravel() _axis = 0 elif isinstance(axis, tuple): # Note: The axis argument supports any number of axes, whereas @@ -118,13 +118,14 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] normalized_axis = normalize_axis_tuple(axis, x.ndim) rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) newshape = axis + rest - x = xp.transpose(x, newshape).reshape( + _x = xp.transpose(x, newshape).reshape( (xp.prod([x.shape[i] for i in axis], dtype=int), *[x.shape[i] for i in rest])) _axis = 0 else: + _x = x _axis = axis - res = xp.linalg.norm(x, axis=_axis, ord=ord) + res = xp.linalg.norm(_x, axis=_axis, ord=ord) if keepdims: # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks From f4488758b201e64dc77d48a414ceed426990e9c1 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 27 Feb 2024 15:35:32 -0700 Subject: [PATCH 2/3] Use math.prod to multiply integers --- array_api_compat/common/_linalg.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 58fe5491..dc2b69d8 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -5,6 +5,8 @@ from typing import Literal, Optional, Tuple, Union from ._typing import ndarray +import math + import numpy as np if np.__version__[0] == "2": from numpy.lib.array_utils import normalize_axis_tuple @@ -119,7 +121,7 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) newshape = axis + rest _x = xp.transpose(x, newshape).reshape( - (xp.prod([x.shape[i] for i in axis], dtype=int), *[x.shape[i] for i in rest])) + (math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest])) _axis = 0 else: _x = x From 0837875faba0eb032515dddfa1f01730b7a61ac9 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 27 Feb 2024 15:42:43 -0700 Subject: [PATCH 3/3] Remove reference to deleted cupy-skips.txt --- test_cupy.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_cupy.sh b/test_cupy.sh index 1f832d2e..6b4d6b56 100755 --- a/test_cupy.sh +++ b/test_cupy.sh @@ -26,4 +26,4 @@ mkdir -p $SCRIPT_DIR/.hypothesis ln -s $SCRIPT_DIR/.hypothesis .hypothesis export ARRAY_API_TESTS_MODULE=array_api_compat.cupy -pytest ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt --skips-file $SCRIPT_DIR/cupy-skips.txt "$@" +pytest ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@"