Pathfinder: allow mode="JAX" for pytensor backend compiler #425
Comments
There might not be a need to modify the pytensor component of the code to allow mode="JAX". Can the perform method of that Op be dispatched to JAX?

I think being able to dispatch to jax/numba is a good thing regardless of whether we are able to hook up to a separate implementation in blackjax. We should also prefer direct pytensor implementations (using existing atomic Ops vs making a new Op with a numpy perform) over writing new Ops for exactly this reason. There's also less maintenance overhead, and there might be optimization rewrites in this logp graph that are being hidden from pytensor.
This is something you should be able to do with PyTensor vectorize / vectorize_graph directly.
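A minimal sketch of that suggestion, assuming a toy logp graph over a flat parameter vector (the graph and variable names are illustrative, not the actual Pathfinder code):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Toy stand-in for a logp graph defined on a flat parameter vector.
x = pt.vector("x")
logp = -0.5 * (x**2).sum()

# A batch of parameter vectors, one row per evaluation point.
xs = pt.matrix("xs")

# vectorize_graph rewrites the existing graph to map over the extra
# leading dimension, so no custom Op with a numpy perform is needed.
batched_logp = vectorize_graph(logp, replace={x: xs})

f = pytensor.function([xs], batched_logp)
print(f(np.random.randn(5, 3)).shape)  # (5,)
```

Because the batched logp stays an ordinary pytensor graph built from existing atomic Ops, it can then be compiled with mode="JAX" through the dispatch rules those Ops already have.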
Would like to extend the pytensor backend of Pathfinder to compile using JAX by setting compile_kwargs=dict(mode="JAX") in pmx.fit. Not yet entirely sure what the speed advantage (if any) there is. However, I think the solution to the problem below might not be too difficult. A required fix may be to implement a JAX conversion for the LogLike Op below. (The reason for having the LogLike Op was to vectorise an existing compiled model.logp() function, which takes in a flattened array of the model parameters.)

pymc-extras/pymc_extras/inference/pathfinder/pathfinder.py
Lines 693 to 716 in 00a4ca3
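A hedged sketch of what such a conversion could look like, using pytensor's jax_funcify dispatch mechanism. The import path mirrors the file referenced above, and op.logp_jax is a hypothetical attribute standing in for a JAX-traceable version of the wrapped logp function; neither is part of the current implementation, which wraps an already-compiled numpy function that JAX cannot trace:

```python
import jax
from pytensor.link.jax.dispatch import jax_funcify

from pymc_extras.inference.pathfinder.pathfinder import LogLike

@jax_funcify.register(LogLike)
def jax_funcify_LogLike(op, **kwargs):
    def loglike(x):
        # x: (batch, n_params) array of flattened parameter vectors;
        # map the (assumed) JAX-traceable scalar logp over the rows.
        return jax.vmap(op.logp_jax)(x)

    return loglike
```

The hard part is supplying that JAX-traceable logp in the first place, which is one reason the direct-pytensor-graph route suggested in the comments above may be preferable.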
Minimum working example:
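A minimal sketch of the kind of example described, assuming a toy Gaussian model (the model and data here are illustrative):

```python
import numpy as np
import pymc as pm
import pymc_extras as pmx

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu=mu, sigma=1.0, observed=np.array([0.1, -0.3, 0.2]))

    idata = pmx.fit(
        method="pathfinder",
        # Ask the pytensor backend to compile with JAX; with the current
        # LogLike Op this is expected to fail for lack of a registered
        # JAX conversion.
        compile_kwargs=dict(mode="JAX"),
    )
```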
Output: