Skip to content

Fix bug in fit_MAP when shared variables are used in graph #468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 2, 2025

Conversation

jessegrabowski
Copy link
Member

Currently, fit_MAP fails on graphs with shared variables when gradient_backend='jax'. This is because we are discarding the pytensor function wrapper that helps work with them. This PR rewrites shared variables into constants using _replace_shared_variables in this case. This is the same approach used in pymc.sampling.jax.get_jaxified_graph

@jessegrabowski jessegrabowski added the bug Something isn't working label May 2, 2025
if use_jax_gradients:
from pymc.sampling.jax import _replace_shared_variables

[loss] = _replace_shared_variables([loss])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw this is the sort of stuff that can lead to big constant foldings in JAX

@jessegrabowski jessegrabowski force-pushed the shared-vars-map-jax branch from 0d5c33c to f1bb479 Compare May 2, 2025 20:27
@jessegrabowski jessegrabowski merged commit 413a4cb into pymc-devs:main May 2, 2025
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants