Skip to content

Fix JAX warnings in tests #307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 12, 2024
Merged

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Feb 11, 2024

Closes #305

JAX warning mentioned #305 seems to be caused by using multiple cores when sampling with pm.sample in tests. The JAX test appears to be over-eager, because it is issued even when JAX is not being used, as long as JAX has been imported.

This PR adds the warning to filterwarnings in the pyproject.toml, since it doesn't appear to be relevant to the tests its causing to fail.

@jessegrabowski
Copy link
Member Author

jessegrabowski commented Feb 12, 2024

I wanted to filter by module (jax._src.xla_bridge), but it didn't work. This does, so I'm going with it.

By far my worst work, but it fixes it issue.

I'd still like to move tests up one level, out of the project files.

@ricardoV94
Copy link
Member

I'd still like to move tests up one level, out of the project files.

Let's do that in a separate PR

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

test_histogram_approximation failing due to warning in newer JAX release
2 participants