sign complex case implementations #183
Comments
sign for torch was already fixed at https://github.com/data-apis/array-api-compat/pull/137/files. I didn't realize cupy had the issue too. Do older versions of NumPy have this problem as well?
Yes. Basically everything needs to be patched unless it is recent enough. (I know the torch error would not be present with
import array_api_compat
print(array_api_compat.__version__) # 1.8

from array_api_compat import numpy as xp
print(xp.__version__) # 1.26.4
x = xp.asarray(1 + 2j)
print(xp.sign(x)) # (1+0j)

import cupy as cp
print(cp.__version__) # 12.2.0
from array_api_compat import cupy as xp
x = xp.asarray(1 + 2j)
print(xp.sign(x)) # (1+0j)

import torch
print(torch.__version__) # 2.4.1+cu121
from array_api_compat import torch as xp
x = xp.asarray(1 + 2j)
# print(xp.sign(x)) # RuntimeError: Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.

import dask
print(dask.__version__) # 2024.8.0
from array_api_compat.dask import array as xp
x = xp.asarray(1 + 2j)
print(xp.sign(x)) # dask.array<sign, shape=(), dtype=complex128, chunksize=(), chunktype=numpy.ndarray>

import array_api_strict as xp
print(xp.__version__) # 2.0.1
x = xp.asarray(1 + 2j)
print(xp.sign(x)) # (1+0j)

import jax
print(jax.__version__) # 0.4.26
import jax.numpy as xp
x = xp.asarray(1 + 2j)
print(xp.sign(x)) # (0.44721365+0.8944273j)
Basically
Interesting. The test suite should be checking this as far as I can tell, but it hasn't come up, even though we do explicitly test against older versions of NumPy. That will require some investigation.
So I dug into this and it looks like the test suite has been ignoring any exceptions raised in the reference implementations in the elementwise function tests. This appears to affect quite a few functions, although it isn't clear yet if there are any actual unwrapped incompatibilities due to this other than this one.
For some reason, "except OverflowError" was changed to "except Exception" in e72184e. For now I have removed the except entirely, but it's possible we may need to keep the handling for OverflowError. There are several issues with tests that this was masking, which I have not fixed yet. Quite a few tests are not testing the complex implementations correctly because they are using math instead of cmath, for example. See also data-apis/array-api-compat#183.
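The masking effect described above can be reproduced in a few lines. This is only a sketch under stated assumptions, not the test suite's actual code: `run_check` is a hypothetical stand-in for a harness that calls a reference implementation inside a `try` block, and `math.sqrt` versus `cmath.sqrt` stands in for any reference implementation that uses `math` where `cmath` is needed.

```python
import cmath
import math


def run_check(refimpl, value, catch):
    """Hypothetical harness: call a reference implementation and report
    whether the comparison step would have been reached."""
    try:
        refimpl(value)
    except catch:
        # The reference raised, so the comparison is silently skipped.
        return "skipped"
    return "checked"


z = 1 + 2j

# With a broad handler, the TypeError from using math on a complex
# input is swallowed, and the complex case is never actually tested.
print(run_check(math.sqrt, z, Exception))       # skipped

# With the narrow handler, the same mistake surfaces as an error.
try:
    run_check(math.sqrt, z, OverflowError)
except TypeError as exc:
    print("surfaced:", type(exc).__name__)

# A cmath-based reference handles the complex input and gets checked.
print(run_check(cmath.sqrt, z, OverflowError))  # checked

# Genuine overflow in the reference is still tolerated.
print(run_check(math.exp, 1000.0, OverflowError))  # skipped
```

Keeping the handler narrow means a reference implementation that cannot handle complex input fails loudly instead of quietly reducing coverage.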
Since the 2022.12 standard, the required implementation of sign has been:

$$\operatorname{sign}(x_i) = \begin{cases} 0 & \text{if } x_i = 0 \\ \dfrac{x_i}{|x_i|} & \text{otherwise} \end{cases}$$

but I think only the most recent versions of libraries follow this (if any). Older versions of all libraries and even the most recent versions of some (e.g. CuPy, and even array_api_strict, which I can report separately if need be) use other conventions. It would be helpful if all libraries had aliases of sign that use the new definition.
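As a rough illustration of what such an alias could look like, here is a minimal NumPy sketch of the x/|x| convention, which is the one the JAX output in the comments follows. The name `sign_2022_12` is hypothetical and is not part of array-api-compat or any other library. It assumes the complex case is defined as 0 for zero input and x/|x| otherwise, and it defers to `np.sign` for non-complex dtypes.

```python
import numpy as np


def sign_2022_12(x):
    """Hypothetical alias: complex sign as x / |x|, with sign(0) == 0."""
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.complexfloating):
        # Real and integer dtypes already behave consistently.
        return np.sign(x)
    magnitude = np.abs(x)
    is_zero = magnitude == 0
    # Replace zero magnitudes with 1 so the division never hits 0/0.
    safe_magnitude = np.where(is_zero, 1, magnitude)
    result = np.where(is_zero, 0, x / safe_magnitude)
    # Keep the input dtype (e.g. complex64 stays complex64).
    return result.astype(x.dtype)


print(sign_2022_12(1 + 2j))          # approximately 0.4472+0.8944j
print(sign_2022_12([3 + 4j, 0j]))    # approximately [0.6+0.8j, 0j]
```

A compat layer could apply the same idea per backend, for example by delegating to `torch.sgn` for torch, as the RuntimeError message in the comments suggests.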