PyMC Implementation of Pathfinder VI #386

Closed
wants to merge 10 commits into from

aphc14 (Contributor) commented Oct 31, 2024

Summary:

  • Adds a PyMC implementation of Pathfinder Variational Inference using PyTensor operations. Users can choose between the PyMC and BlackJAX backends through the same API.
  • Adds an lbfgs.py module implementing L-BFGS optimisation with history tracking.
  • Extends pathfinder.py with the PyMC implementation using PyTensor operations.

Note: Another draft PR will be sent that focuses on a PyTensor symbolic implementation using pytensor.function. I've sent two PR drafts to get feedback on which version would be better.

import pymc as pm
import pymc_experimental as pmx

with pm.Model() as model:
    # ... model definition ...
    idata = pmx.fit(
        method="pathfinder",
        inference_backend="pymc",  # or "blackjax"
        random_seed=42
    )

with model:
    # eight_schools_model
    idata = pmx.fit(model=model, method="pathfinder", random_seed=41, inference_backend="pymc")
    
    # New implementation now passes this assertion! :)
    # FIXME: pathfinder doesn't find a reasonable mean! Fix bug or choose model pathfinder can handle
    np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.0)

    # But, it also fails this :(
    # FIXME: now the tau is being underestimated. getting tau around 1.5.
    # np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=0.5)

`fit_pathfinder`
- Edited `fit_pathfinder` to produce `pathfinder_state`, `pathfinder_info`, `pathfinder_samples` and `pathfinder_idata` for closer examination of the outputs.
- Renamed the `num_samples` argument to `num_draws` to avoid `TypeError: got multiple values for keyword argument 'num_samples'`.
- Initial points are now always jittered, because Pathfinder requires jitter.
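
The `TypeError` above is what Python raises when the same keyword reaches a function twice. A minimal sketch of how that can happen in a wrapper that forwards `**kwargs`, and how a rename sidesteps it. The names `backend`, `fit_old` and `fit_new` are hypothetical and are not taken from this PR:

```python
def backend(num_samples, **kwargs):
    """Hypothetical stand-in for a lower-level function that owns `num_samples`."""
    return num_samples


def fit_old(**kwargs):
    # The wrapper sets num_samples itself AND forwards the caller's kwargs,
    # so a caller-supplied num_samples arrives twice.
    return backend(num_samples=1000, **kwargs)


def fit_new(num_draws=1000, **kwargs):
    # The public argument has a different name, so it cannot collide.
    return backend(num_samples=num_draws, **kwargs)


try:
    fit_old(num_samples=50)
except TypeError as err:
    collision_message = str(err)  # mentions "multiple values for keyword argument"

draws = fit_new(num_draws=50)
```

This only illustrates the failure mode; the actual call chain inside `fit_pathfinder` may differ.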

Extras
- New function `get_jaxified_logp_ravel_inputs` to simplify the previous code structure in `fit_pathfinder`.

Tests
- Added an extra Pathfinder test to check that the `pathfinder_info` variables and `pathfinder_idata` are consistent for a given random seed.

Add a new PyMC-based implementation of Pathfinder VI that uses PyTensor operations, so that `fit_pathfinder` supports both the PyMC and BlackJAX backends.

twiecki (Member) commented Oct 31, 2024

This looks great @aphc14, can you find me on LinkedIn?

- Implemented  in  to support running multiple Pathfinder instances in parallel.
- Implemented  function in  for Pareto Smoothed Importance Resampling (PSIR).
- Moved relevant pathfinder files into the  directory.
- Updated tests to reflect changes in the Pathfinder implementation and added tests for new functionalities.

aphc14 (Contributor, Author) commented Nov 4, 2024

@twiecki yup, added you :)
@fonnesbeck, FYI, Multipath Pathfinder has just been implemented!

aphc14 added a commit to aphc14/pymc-extras that referenced this pull request Nov 7, 2024
@aphc14 aphc14 closed this Nov 11, 2024

aphc14 (Contributor, Author) commented Nov 11, 2024

closed this in favour of #387

fonnesbeck pushed a commit that referenced this pull request Jan 27, 2025
* renamed samples argument name and pathfinder variables to avoid confusion

* Minor changes made to the `fit_pathfinder` function and added test

`fit_pathfinder`
- Edited `fit_pathfinder` to produce `pathfinder_state`, `pathfinder_info`, `pathfinder_samples` and `pathfinder_idata` for closer examination of the outputs.
- Renamed the `num_samples` argument to `num_draws` to avoid `TypeError: got multiple values for keyword argument 'num_samples'`.
- Initial points are now always jittered, because Pathfinder requires jitter.

Extras
- New function `get_jaxified_logp_ravel_inputs` to simplify the previous code structure in `fit_pathfinder`.

Tests
- Added an extra Pathfinder test to check that the `pathfinder_info` variables and `pathfinder_idata` are consistent for a given random seed.

* extract additional pathfinder objects from high level API for debugging

* changed pathfinder samples argument to  num_draws

* feat(pathfinder): add PyMC-based Pathfinder VI implementation

Add a new PyMC-based implementation of Pathfinder VI that uses PyTensor operations, so that `fit_pathfinder` supports both the PyMC and BlackJAX backends.

* Multipath Pathfinder VI implementation in pymc-experimental

- Implemented  in  to support running multiple Pathfinder instances in parallel.
- Implemented  function in  for Pareto Smoothed Importance Resampling (PSIR).
- Moved relevant pathfinder files into the  directory.
- Updated tests to reflect changes in the Pathfinder implementation and added tests for new functionalities.

* Added type hints and epsilon parameter to fit_pathfinder

* Removed initial point values (l=0) to reduce iterations. Simplified  and .

* Added placeholder/reminder to remove jax dependency when converting trace data to InferenceData

* Sync updates with draft PR #386.

- Added pytensor.function for bfgs_sample

* Reduced size of compute graph with pathfinder_body_fn

Summary of changes:
- Remove multiprocessing code in favour of reusing compiled  for each path
-  takes only random_seed as argument for each path
- Compute graph significantly smaller by using pure pytensor ops and symbolic variables
- Added LBFGSOp to compile with pytensor.function
- Cleaned up code using pytensor variables

* - Added TODO comments for implementing Taylor approximation methods:  and .
- Corrected the dimensions in comments for matrices Q and R in the  function.
- Improved numerical stability in the  calculation by changing from  to .

* fix: correct posterior approximations in Pathfinder VI

Fixed incorrect and inconsistent posterior approximations in the Pathfinder VI
algorithm by:

1. Adding missing parentheses in the phi calculation to ensure proper order
   of operations in matrix multiplications
2. Changing the sign in mu calculation from 'x +' to 'x -' to match Stan's
   implementation (which differs from the original paper)

The resulting changes now make the posterior approximations more reliable.

* feat: Add dense BFGS sampling for Pathfinder VI

Implements both sparse and dense BFGS sampling approaches for Pathfinder VI:
- Adds bfgs_sample_dense for cases where 2*maxcor >= num_params.
- Moved existing  and  computations to bfgs_sample_sparse, making the sparse use cases more explicit.

Other changes:
- Sets default maxcor=5 instead of dynamic sizing based on parameters

Dense approximations are recommended when the parameters of the target distribution are strongly dependent on one another.
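
The selection rule stated above (`2*maxcor >= num_params` selects the dense sampler) can be sketched as a small dispatch helper. The function name `choose_bfgs_sampler` and the string return values are hypothetical; only the inequality comes from the commit message:

```python
def choose_bfgs_sampler(maxcor: int, num_params: int) -> str:
    """Pick the sampling variant from the history size and parameter count.

    Assumption: 'dense' stands for bfgs_sample_dense and 'sparse' for
    bfgs_sample_sparse; the real code may dispatch differently.
    """
    if 2 * maxcor >= num_params:
        return "dense"
    return "sparse"


# With the default maxcor=5, models with up to 10 parameters take the dense route.
small_model = choose_bfgs_sampler(maxcor=5, num_params=10)
large_model = choose_bfgs_sampler(maxcor=5, num_params=11)
```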

* feat: improve Pathfinder performance and compatibility

Bigger changes:
- Made pmx.fit compatible with method='pathfinder'
- Remove JAX dependency when inference_backend='pymc' to support Windows users
- Improve runtime performance by setting trust_input=True for compiled functions

Minor changes:
- Change default num_paths from 1 to 4 for stable and reliable approximations
- Change LBFGS code using dataclasses
- Update tests to handle both PyMC and BlackJAX backends

* minor: improve error handling in Pathfinder VI

- Add LBFGSInitFailed exception for failed LBFGS initialisation
- Skip failed paths in multipath_pathfinder and track number of failures
- Handle NaN values from Cholesky decomposition in bfgs_sample
- Add checks for numerical stability in matrix operations
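
A minimal sketch of the "skip failed paths and count them" pattern described above. `LBFGSInitFailed` is named in the commit message, but the class body, `run_paths`, and the toy path functions here are hypothetical stand-ins rather than the PR's code:

```python
class LBFGSInitFailed(Exception):
    """Stand-in for the exception raised when L-BFGS cannot initialise."""


def run_paths(path_fns):
    """Run each path, keep the successful results, and count the failures."""
    results = []
    num_failed = 0
    for path_fn in path_fns:
        try:
            results.append(path_fn())
        except LBFGSInitFailed:
            # Skip this path instead of aborting the whole run.
            num_failed += 1
    return results, num_failed


def good_path():
    return "ok"


def bad_path():
    raise LBFGSInitFailed("initial point has non-finite logp")


results, num_failed = run_paths([good_path, bad_path, good_path])
```

Catching only the specific exception keeps unrelated bugs visible, since they still propagate.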

Slight performance improvements:
- Set allow_gc=False in scan ops
- Use FAST_RUN mode consistently

* Progress bar and other minor changes

Major:
  - Added progress bar support.

Minor
  - Added  exception for non-finite log prob values
  - Removed .
  - Allowed maxcor argument to be None, and dynamically set based on the number of model parameters.
  - Improved logging to inform users about failed paths and lbfgs initialisation.

* set maxcor to max(5, floor(N / 1.9)). max=1 will cause error
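
The formula in this commit title can be written directly in Python. A minimal sketch, assuming `N` is the number of model parameters; the helper name `default_maxcor` is hypothetical:

```python
import math


def default_maxcor(num_params: int) -> int:
    """Return max(5, floor(N / 1.9)) for N = num_params."""
    return max(5, math.floor(num_params / 1.9))


tiny = default_maxcor(3)     # floor(3 / 1.9) = 1, so the floor of 5 applies
large = default_maxcor(100)  # floor(100 / 1.9) = 52
```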

* Refactor Pathfinder VI: Default to PSIS, Add Concurrency, and Improved Computational Performance

- Significantly improved computational efficiency by combining 3 computational graphs into 1 larger compile. Removed non-shared inputs and used  with  for significant performance gains.
- Set default importance sampling method to 'psis' for more stable posterior results, avoiding local peaks seen with 'psir'.
- Introduced concurrency options ('thread' and 'process') for multithreading and multiprocessing. Defaults to no concurrency, as neither option has reduced compute time much, if at all.
- Adjusted default  from 8 to 4 and  from 1.0 to 2.0 and maxcor to max(3*log(N), 5). These defaults reduce computation time and the degree to which the posterior variance is underestimated.
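
One way to expose 'thread' and 'process' options with a sequential default is the standard-library `concurrent.futures` module. This is a sketch under that assumption; `fake_path` and `run_all_paths` are hypothetical, and the PR's own implementation may be structured differently:

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor


def fake_path(seed: int) -> int:
    """Hypothetical stand-in for one Pathfinder path; must be picklable for 'process'."""
    return seed * 2


def run_all_paths(seeds, concurrency=None):
    """Run one path per seed, sequentially by default."""
    if concurrency is None:
        return [fake_path(seed) for seed in seeds]
    executors = {"thread": ThreadPoolExecutor, "process": ProcessPoolExecutor}
    if concurrency not in executors:
        raise ValueError(f"unknown concurrency option: {concurrency!r}")
    with executors[concurrency]() as pool:
        # map() preserves the order of the seeds in its results.
        return list(pool.map(fake_path, seeds))


sequential = run_all_paths([1, 2, 3])
threaded = run_all_paths([1, 2, 3], concurrency="thread")
```

On platforms that spawn worker processes, the 'process' option additionally needs the calling code to sit under an `if __name__ == "__main__":` guard.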

* Improvements to Importance Sampling and InferenceData shape

- Handle different importance sampling methods for reshaping and adjusting log densities.
- Modified  to return InferenceData with chain dim of size num_paths when

* Display summary of results, Improve error handling, General improvements

Changes:
- Add rich table summary display for results
- Added PathStatus and LBFGSStatus for error handling, status tracking and displaying results
- Changed importance_sampling return type to ImportanceSamplingResult
- Changed multipath_pathfinder return type to MultiPathfinderResult
- Added dataclass containers for results (ImportanceSamplingResult, PathfinderResult, MultiPathfinderResult)
- Refactored LBFGS by removing PyTensor Op classes in favor of pure functions
- Added timing and configuration tracking
- Improve concurrency with better error handling
- Improved docstrings and type hints
- Simplified logp and gradient computation by combining into single function
- Added compile_kwargs parameter for pytensor compilation options
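
A rough sketch of how status enums and dataclass containers of this kind can fit together. The class names `PathStatus`, `PathfinderResult` and `MultiPathfinderResult` come from the list above, but every member, field and property shown here is a hypothetical illustration, not the PR's definition:

```python
from dataclasses import dataclass, field
from enum import Enum, auto


class PathStatus(Enum):
    # Hypothetical members; the real enum may define different states.
    SUCCESS = auto()
    LBFGS_FAILED = auto()


@dataclass(frozen=True)
class PathfinderResult:
    """Outcome of a single path (fields are assumptions)."""
    status: PathStatus
    samples: list[float] = field(default_factory=list)


@dataclass
class MultiPathfinderResult:
    """Collects the per-path results so they can be summarised together."""
    paths: list[PathfinderResult] = field(default_factory=list)

    @property
    def num_failed(self) -> int:
        return sum(p.status is not PathStatus.SUCCESS for p in self.paths)


result = MultiPathfinderResult(
    paths=[
        PathfinderResult(PathStatus.SUCCESS, [0.1, 0.2]),
        PathfinderResult(PathStatus.LBFGS_FAILED),
    ]
)
```

Returning typed containers instead of tuples lets a summary table read fields by name.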

* Move pathfinder module to pymc_extras

- Move pathfinder module from pymc_experimental to pymc_extras
- Update directory structure to match upstream repository

* Improve pathfinder error handling and type hints

- Add proper type hints throughout pathfinder module
- Improve error handling in concurrent execution paths
- Better handling of the case where all paths fail, by displaying results before the assertion
- Changed Australian English spelling to US
- Update compile_pymc usage to handle deprecation warning
- Add tests for concurrent execution and seed reproducibility
- Clean up imports and remove redundant code
- Improve docstrings and error messages

* fix: Use typing_extensions.Self for Python 3.10 compatibility
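
`typing.Self` was added to the standard library in Python 3.11, so code that must also run on 3.10 needs the `typing_extensions` backport. A minimal sketch using a version check; the `Config` class is a hypothetical example, and the commit itself may simply import from `typing_extensions` unconditionally:

```python
import sys

if sys.version_info >= (3, 11):
    from typing import Self
else:
    # Third-party backport, required on Python 3.10 and earlier.
    from typing_extensions import Self


class Config:
    """Hypothetical class with a method that returns its own type."""

    def __init__(self, maxcor: int = 5) -> None:
        self.maxcor = maxcor

    def with_maxcor(self, maxcor: int) -> Self:
        # type(self) keeps the return type correct for subclasses too.
        return type(self)(maxcor)


updated = Config().with_maxcor(7)
```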