Skip to content

ENH: add fallback_namespace #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _check_api_version(api_version):
if api_version is not None and api_version != '2021.12':
raise ValueError("Only the 2021.12 version of the array API specification is currently supported")

def array_namespace(*xs, api_version=None, _use_compat=True):
def array_namespace(*xs, api_version=None, _use_compat=True, fallback_namespace=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fallback namespace should still be an array API compatible namespace. So if someone does fallback_namespace=numpy instead of fallback_namespace=array_api_compat.numpy, we might want to automatically convert for them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some logic to try to convert the namespace. I did not see a function to do that but I suppose this could be made into it's own function if there are more needs.

"""
Get the array API compatible namespace for the arrays `xs`.

Expand All @@ -69,10 +69,23 @@ def your_function(x, y):
api_version should be the newest version of the spec that you need support
for (currently the compat library wrapped APIs only support v2021.12).
"""
# convert fallback_namespace
if fallback_namespace is not None:
try:
x_ = fallback_namespace.asarray(1)
fallback_namespace = array_namespace(
x_, _use_compat=_use_compat
)
except AttributeError as exc:
msg = "'fallback_namespace' must be an Array API compatible namespace"
raise TypeError(msg) from exc

namespaces = set()
for x in xs:
if isinstance(x, (tuple, list)):
namespaces.add(array_namespace(*x, _use_compat=_use_compat))
namespaces.add(array_namespace(
*x, _use_compat=_use_compat, fallback_namespace=fallback_namespace
))
elif hasattr(x, '__array_namespace__'):
namespaces.add(x.__array_namespace__(api_version=api_version))
elif _is_numpy_array(x):
Expand All @@ -99,6 +112,8 @@ def your_function(x, y):
else:
import torch
namespaces.add(torch)
elif fallback_namespace is not None:
namespaces.add(fallback_namespace)
else:
# TODO: Support Python scalars?
raise TypeError("The input is not a supported array type")
Expand Down
26 changes: 26 additions & 0 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,32 @@ def test_array_namespace_multiple():
assert array_namespace(x, x) == array_namespace((x, x)) == \
array_namespace((x, x), x) == array_api_compat.numpy

def test_fallback_namespace():
import numpy as np
import numpy.array_api
import array_api_compat.numpy

xp = array_api_compat.numpy
xp_ = array_namespace([1, 2], fallback_namespace=xp)
assert xp_ == xp

xp_ = array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=xp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test to make sure the fallback namespace only applies when it's ambiguous. Something like array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=numpy.array_api) should return array_api_compat.numpy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure how this would work. To me that's the purpose of the fallback. With

array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=numpy.array_api)

I would expect [1, 2] to fallback to numpy.array_api. I suppose that in this case we could say ok fallback [1, 2] to be the same as the other namespace. But what if instead we have:

array_namespace([1, 2], cp.asarray([1, 2]), fallback_namespace=numpy.array_api)

Shall this also simplify to cp??

(Due to the recursions it's also getting tricky to follow the logic.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I'm missing the point of this then. In general I'd say you shouldn't mix different array libraries with the array API. Some libraries might support it (like NumPy being able to convert other arrays), and in some cases with a performance penalty (like moving from GPU to CPU). So array_namespace([1, 2], cp.asarray([1, 2]), fallback_namespace=numpy.array_api) should give an error.

So to me the only point of this is to allow this function to also accept non-array inputs, like Python scalars and/or lists of Python scalars. You can convert those to arrays with asarray, but you can't call asarray until you have a namespace, making it a chicken/egg situation. Of course, then it might still be a good idea to make sure the errors from this function are helpful.

I also tend to agree that the best way to handle this is to simply not allow non-array inputs to functions that use the array API. That forces the users to resolve the ambiguity of what array library the want to use by calling asarray themselves.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, this function automatically denests lists of arrays. This was to internally support functions like concat which accept lists of arrays. But I'm wondering if this behavior should be removed, so that there there is less ambiguity and we can unconditionally error on list or tuple input.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the main purpose of the proposal was to assume a namespace for things that would otherwise be marked as non compatible or unknown (list, scalar, tuples, else). Then yes it's on the user to call asarray and whether or not it works is on their side.

By the way, this function automatically denests lists of arrays.

Yeah this is making the logic difficult IMHO. I was having trouble to find a way to address your comment with that in place. If that can be removed, I am happy to do so (here or another PR, let me know.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the main purpose of the proposal was to assume a namespace for things that would otherwise be marked as non compatible or unknown (list, scalar, tuples, else).

I think that should live in SciPy, because it's orthogonal to array API support. Use of lists and other "array-likes" is specific to NumPy, and is by and large considered a design mistake. The array API standard does not allow array-likes, and neither do other libraries like PyTorch and CuPy. So there's no reason for this package to care about array-likes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If that can be removed, I am happy to do so (here or another PR, let me know.)

If you can do that it would be great. I'm still unclear what the changes here are (do we actually want this or not?), so maybe a separate PR is appropriate. You'll need to fix the functions like concat that do accept a list. Right now the wrappers assume they can just pass things through to array_namespace.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that should live in SciPy, because it's orthogonal to array API support. Use of lists and other "array-likes" is specific to NumPy, and is by and large considered a design mistake. The array API standard does not allow array-likes, and neither do other libraries like PyTorch and CuPy. So there's no reason for this package to care about array-likes.

I agree. For functions that are wrapped functions from NumPy (or any other library), we should support whatever the wrapped function supports to maintain maximal compatibility. But for new functions, we can be a little more strict. And generally speaking array_namespace() is an important function that will be used everywhere so we should try to discourage bad habits with it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok then it looks like we could close this PR and I could try to propose a PR to only accept Array API compatible arrays. Is that correct?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that sounds good.

assert xp_ == xp

# convert to Array API compatible namespace
xp = array_api_compat.numpy
xp_ = array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace=np)
assert xp_ == xp

msg = 'Multiple namespaces'
with pytest.raises(TypeError, match=msg):
array_namespace([1, 2], numpy.array_api.asarray([1, 2]), fallback_namespace=np)

msg = "'fallback_namespace' must be an Array API"
with pytest.raises(TypeError, match=msg):
array_namespace([1, 2], np.asarray([1, 2]), fallback_namespace="hop")


def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace([1]))
pytest.raises(TypeError, lambda: array_namespace())
Expand Down