-
Notifications
You must be signed in to change notification settings - Fork 132
Fix JAX implementation of Argmax #809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
71fe3b9
a170d33
d0fdd6e
63b0832
68e245f
3679791
ae0801d
f287d6c
8b05952
88b17b3
349beef
8ed6ab2
b50a906
89c7bff
2c4a378
c3c5fd0
be9c9f8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -82,7 +82,7 @@ def assert_fn(x, y): | |
|
||
@pytest.mark.xfail( | ||
version_parse(jax.__version__) >= version_parse("0.2.12"), | ||
reason="Omnistaging cannot be disabled", | ||
reason="JAX Numpy API does not support dynamic shapes", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does this test fail? Also seems like it should just work? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was looking at it. It actually failed, but not due to dynamic shape error with omnistaging. For this one, when I removed the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another problem I saw when I accidentally updated NumPy to 2.0 is many core APIs broke down. One example is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm can you try to change There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PyTensor is explicitly not compatible with numpy 2.0 at the moment: #688 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
def jax_funcified_fgraph(tensor_variable):
# MaxAndArgmax{axis=(0,)}(<Vector(float64, shape=(2,))>)
tensor_variable_1, argmax = maxandargmax(tensor_variable)
# Mul(max, argmax)
tensor_variable_2 = elemwise_fn(tensor_variable_1, argmax)
return (tensor_variable_2,) The solution from here suggests that we need to somehow mark |
||
) | ||
def test_jax_basic_multiout_omni(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should rename the test to be just test There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes that's the only test for Max/Argmax we have. |
||
# Test that a single output of a multi-output `Op` can be used as input to | ||
|
@@ -96,7 +96,7 @@ def test_jax_basic_multiout_omni(): | |
|
||
@pytest.mark.xfail( | ||
version_parse(jax.__version__) >= version_parse("0.2.12"), | ||
reason="Omnistaging cannot be disabled", | ||
reason="`dot` -> `Gemv` optimization is incompatible with JAX", | ||
HangenYuu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
def test_tensor_basics(): | ||
y = vector("y") | ||
|
Uh oh!
There was an error while loading. Please reload this page.