diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 91c4d9a7..7a90f444 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -12,7 +12,7 @@ from typing import NamedTuple import inspect -from ._helpers import array_namespace, _check_device, device, is_torch_array +from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace # These functions are modified from the NumPy versions. @@ -530,6 +530,21 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: raise ValueError("Input array must be at least 1-d.") return tuple(xp.moveaxis(x, axis, 0)) +# numpy 1.26 does not use the standard definition for sign on complex numbers + +def sign(x: ndarray, /, xp, **kwargs) -> ndarray: + if isdtype(x.dtype, 'complex floating', xp=xp): + out = (x/xp.abs(x, **kwargs))[...] + # sign(0) = 0 but the above formula would give nan + out[x == 0+0j] = 0+0j + else: + out = xp.sign(x, **kwargs) + # CuPy sign() does not propagate nans. See + # https://github.com/data-apis/array-api-compat/issues/136 + if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): + out[xp.isnan(x)] = xp.nan + return out[()] + __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', @@ -537,4 +552,4 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: 'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims', 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', - 'unstack'] + 'unstack', 'sign'] diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 30ae2943..3627fb6b 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -62,6 +62,7 @@ matmul = get_xp(cp)(_aliases.matmul) matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) +sign = get_xp(cp)(_aliases.sign) _copy_default = object() @@ -109,13 +110,6 @@ def asarray( return cp.array(obj, dtype=dtype, **kwargs) -def sign(x: ndarray, /) -> ndarray: - # CuPy sign() does not propagate nans. See - # https://github.com/data-apis/array-api-compat/issues/136 - out = cp.sign(x) - out[cp.isnan(x)] = cp.nan - return out - # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index a24694f3..ee2d88c0 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -104,7 +104,7 @@ def _dask_arange( trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) - +sign = get_xp(np)(_aliases.sign) # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 355215e4..2bfc98ff 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -62,6 +62,7 @@ matmul = get_xp(np)(_aliases.matmul) matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) +sign = get_xp(np)(_aliases.sign) def _supports_buffer_protocol(obj): try: diff --git a/torch-xfails.txt b/torch-xfails.txt index c7abe2e9..c972659e 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -56,6 +56,10 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1 array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] +# inverse trig functions are too inaccurate on CPU +array_api_tests/test_operators_and_elementwise_functions.py::test_acos +array_api_tests/test_operators_and_elementwise_functions.py::test_atan +array_api_tests/test_operators_and_elementwise_functions.py::test_asin # overflow near float max array_api_tests/test_operators_and_elementwise_functions.py::test_log1p