How to discover an array type's namespace from an input type annotation #948
nicholasjng asked this question in Q&A
In machine learning workflow engines, I've come across a lot of code similar to this (with JAX as the array API implementer of choice in this example):
where the array I/O needed for the output of `zeros` and the input of `compute_sum` is abstracted away from the user in the `@task` decorator.

A lot of workflow libraries allow the user to customize said I/O, perhaps with an interface like the following:
which would be called to instantiate all of the arrays that come out of `zeros` and go into `compute_sum`, respectively. These happen to be the same in this example.

I think the array API standard would be a good fit for generalizing such I/O machinery to any type that implements the standard, since it could provide a very basic load mechanism for arrays like this:
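A minimal sketch of such a load mechanism, using only the standard library. The `ToyArray` type, the `toy_namespace` object, and the `load_like` helper are hypothetical stand-ins (not part of JAX or any real library): the loader asks an existing array for its namespace via `__array_namespace__()` and then calls the standard `asarray` creation function on the raw data.

```python
from types import SimpleNamespace
from typing import Any, Optional, Sequence


class ToyArray:
    """Hypothetical stand-in for an array type that implements the standard."""

    def __init__(self, data: Sequence[float]) -> None:
        self.data = list(data)

    def __array_namespace__(self, /, *, api_version: Optional[str] = None) -> Any:
        # Like the standard's method, this is called on an array instance.
        return toy_namespace


def _asarray(obj: Any) -> ToyArray:
    # Minimal stand-in for the standard `asarray` creation function.
    return obj if isinstance(obj, ToyArray) else ToyArray(obj)


# Hypothetical namespace object playing the role of e.g. `jax.numpy`.
toy_namespace = SimpleNamespace(asarray=_asarray)


def load_like(data: Sequence[float], template: Any) -> Any:
    """Create an array from raw data in the same namespace as `template`."""
    xp = template.__array_namespace__()  # requires an instance, not a type
    return xp.asarray(data)


arr = load_like([1.0, 2.0, 3.0], ToyArray([0.0]))
print(type(arr).__name__, arr.data)  # ToyArray [1.0, 2.0, 3.0]
```

Note that `load_like` needs an array instance as a template; going from the type annotation alone to the namespace is the step that does not work.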
Unfortunately, it seems that `__array_namespace__` is an instance method only, so I cannot call it on the array type itself. Indeed, with JAX I get the following:

NumPy produces a different error, but I'm also unable to obtain the namespace there. As an added difficulty, JAX's array namespace is `jax.numpy`, but `jax.Array.__module__ == 'jax'`.

Is there any standards-conforming way to deduce the array namespace from an array type annotation alone?
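To make the limitation concrete, here is a stdlib-only sketch with a hypothetical `Array` type and `toy_numpy` namespace (no real library involved). Calling `__array_namespace__` on the class fails because there is no instance to bind to `self`, and the class's `__module__` does not name the namespace, mirroring the `jax.Array` / `jax.numpy` mismatch described above. The exact error text depends on how the array type is implemented.

```python
from types import SimpleNamespace
from typing import Any, Optional

# Hypothetical namespace object, standing in for something like `jax.numpy`.
toy_numpy = SimpleNamespace(name="toylib.numpy")


class Array:
    """Hypothetical array type, standing in for something like `jax.Array`."""

    # Report a top-level package name, as `jax.Array.__module__ == 'jax'` does,
    # even though the array namespace lives elsewhere.
    __module__ = "toylib"

    def __array_namespace__(self, /, *, api_version: Optional[str] = None) -> Any:
        # Instance method: Python must bind an array instance to `self`.
        return toy_numpy


def namespace_from_type(array_type: type) -> Any:
    """Naive attempt: call __array_namespace__ on the class, not an instance."""
    return array_type.__array_namespace__()


error: Optional[TypeError] = None
try:
    namespace_from_type(Array)
except TypeError as exc:
    # No instance was supplied, so the required `self` argument is missing.
    error = exc

print("type-level lookup failed:", error)
print(Array().__array_namespace__().name)  # toylib.numpy
print(Array.__module__)  # toylib
```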