Fix JAX implementation of Argmax #809
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #809      +/-   ##
==========================================
+ Coverage   80.97%   81.01%   +0.03%
==========================================
  Files         169      170       +1
  Lines       47015    46924      -91
  Branches    11497    11497
==========================================
- Hits        38072    38016      -56
+ Misses       6728     6700      -28
+ Partials     2215     2208       -7
tests/link/jax/test_nlinalg.py
Outdated
@@ -82,7 +82,7 @@ def assert_fn(x, y):

 @pytest.mark.xfail(
     version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="Omnistaging cannot be disabled",
+    reason="JAX Numpy API does not support dynamic shapes",
Why does this test fail? Also seems like it should just work?
I looked into it. The test does fail, but not because of a dynamic-shape error related to omnistaging. When I removed the xfail to see the underlying error, I got jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[]. instead. Checking here, I see that the error comes from passing a traced value where a Python integer is expected. To get rid of it, we would need to mark the value as static at the function level. That is another PR, I guess.
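To illustrate why the message mentions __index__(): Python calls that method whenever an object is used where an integer is required, such as a list index. The sketch below uses a hypothetical stand-in class, FakeTracer (not a JAX type), that refuses the conversion. It only mirrors the mechanism in plain Python and says nothing about how JAX implements its tracers.

```python
class FakeTracer:
    """Hypothetical stand-in for a traced value; not part of JAX."""

    def __index__(self):
        # A traced value has no concrete integer to hand back at trace
        # time, so the conversion is refused. We mimic that by raising.
        raise TypeError("__index__() was called on a traced value")


def pick(items, i):
    # Indexing a Python list asks i for an integer via i.__index__().
    return items[i]


def try_pick(items, i):
    # Return the picked element, or a short message if conversion fails.
    try:
        return pick(items, i)
    except TypeError as err:
        return f"failed: {err}"


print(try_pick([10, 20, 30], 1))             # a plain int works
print(try_pick([10, 20, 30], FakeTracer()))  # the stand-in does not
```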
Another problem I saw when I accidentally updated NumPy to 2.0 is that many core APIs broke. One example is np.obj2sctype(dtype), used in pytensor/tensor/type.py:105, which was removed in NumPy 2.0. This one actually made the tests fail at the collection stage. I raised an issue for this one.
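One way to get the same information on both sides of the NumPy 2.0 boundary is np.dtype(...).type, which the NumPy 2.0 migration notes list as the replacement for np.obj2sctype. The helper name scalar_type_of is hypothetical, and this is only a sketch of the idea, not the change made in PyTensor.

```python
import numpy as np


def scalar_type_of(dtype_like):
    # np.dtype accepts strings, Python types, and dtype objects;
    # .type is the corresponding NumPy scalar class.
    return np.dtype(dtype_like).type


print(scalar_type_of("float64"))
print(scalar_type_of(np.int32))
```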
Hmm, can you try changing it to x = dvector(shape=(2,)) and see if it works? Since that's not the concern of the test, it's a shame to have it xfail.
PyTensor is explicitly not compatible with numpy 2.0 at the moment: #688
dvector() has no shape argument, so I tried vector(dtype="float64", shape=(2,)) instead. The error persists. The stack trace points to this temporary function, which is generated when we convert the Op to JAX:
def jax_funcified_fgraph(tensor_variable):
    # MaxAndArgmax{axis=(0,)}(<Vector(float64, shape=(2,))>)
    tensor_variable_1, argmax = maxandargmax(tensor_variable)
    # Mul(max, argmax)
    tensor_variable_2 = elemwise_fn(tensor_variable_1, argmax)
    return (tensor_variable_2,)
The solution from here suggests that we need to somehow mark tensor_variable in the function above as static with functools.partial and jax.jit, but I think that is out of scope for this PR? Maybe we can mark this as xfail for "Operation leads to JAX Tracer object is used in a context where a Python integer is expected (jax.errors.TracerIntegerConversionError)."
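For reference, the functools.partial plus jax.jit pattern mentioned above usually looks like the sketch below. It is a standalone toy example, not the PyTensor dispatch code: ITEMS, pick_traced, pick_static, and demo are hypothetical names, and the static argument is a plain Python int. Static arguments must be hashable, so this does not directly show how an array input such as tensor_variable would be handled.

```python
from functools import partial

import jax

ITEMS = [10, 20, 30]


@jax.jit
def pick_traced(i):
    # Under jit, i is a tracer; indexing a Python list needs a real integer.
    return ITEMS[i]


@partial(jax.jit, static_argnums=0)
def pick_static(i):
    # i is marked static, so it stays a Python int while tracing.
    return ITEMS[i]


def demo():
    # Record what happens with the traced argument, then use the static one.
    try:
        pick_traced(1)
        traced_outcome = "ok"
    except jax.errors.TracerIntegerConversionError:
        traced_outcome = "TracerIntegerConversionError"
    return traced_outcome, int(pick_static(1))


print(demo())
```

The trade-off is that jax.jit recompiles the function for every distinct value of a static argument, so this suits values that change rarely, such as an axis.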
Can you rebase? We got rid of MaxAndArgmax recently.
@ricardoV94 Sadly, the jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[] error persists 🥲.
tests/link/jax/test_nlinalg.py
Outdated
@@ -82,7 +82,7 @@ def assert_fn(x, y):

 @pytest.mark.xfail(
     version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="Omnistaging cannot be disabled",
+    reason="Operation leads to JAX Tracer object is used in a context where a Python integer is expected (jax.errors.TracerIntegerConversionError).",
 )
 def test_jax_basic_multiout_omni():
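The xfail condition in this hunk compares parsed versions rather than raw strings. The sketch below shows why that matters. It assumes version_parse is packaging.version.parse (the import is not visible in the hunk), at_least is a hypothetical helper, and the version strings are made up for illustration.

```python
from packaging.version import parse as version_parse


def at_least(installed, minimum):
    # Compare release components numerically, not character by character.
    return version_parse(installed) >= version_parse(minimum)


# Plain string comparison gets this wrong because "9" sorts after "1".
string_says = "0.2.9" >= "0.2.12"
parsed_says = at_least("0.2.9", "0.2.12")
print(string_says, parsed_says)
```

Any JAX release from 0.2.12 onwards satisfies the condition, so on such an installation the test is always expected to fail and the reason string is the main thing a reader sees.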
We should rename the test to be just test max_and_argmax; it is no longer a multi-output operation, so the old name doesn't make sense. Also, can you confirm this is the only test we had for Max / Argmax in JAX?
Yes that's the only test for Max/Argmax we have.
Description
"Omnistaging can't be disabled" was added as the reason for some expected failures caused by breaking changes introduced in JAX 0.2.12. Since omnistaging has been the default in JAX for 4 years, the affected tests (3 in total) need more descriptive reasons.