Commit eb1183a

Browse files
authored
PyMC/PyTensor Implementation of Pathfinder VI (pymc-devs#387)
* Renamed samples argument name and pathfinder variables to avoid confusion.
* Minor changes made to the `fit_pathfinder` function and added a test for `fit_pathfinder`:
  - Edited `fit_pathfinder` to produce `pathfinder_state`, `pathfinder_info`, `pathfinder_samples` and `pathfinder_idata` for closer examination of the outputs.
  - Changed the `num_samples` argument name to `num_draws` to avoid "TypeError: got multiple values for keyword argument 'num_samples'".
  - Initial points are automatically set to jitter, as jitter is required for pathfinder.
  Extras:
  - New function `get_jaxified_logp_ravel_inputs` to simplify the previous code structure in `fit_pathfinder`.
  Tests:
  - Added an extra test for pathfinder to check that the pathfinder_info variables and pathfinder_idata are consistent for a given random seed.
* Extract additional pathfinder objects from the high-level API for debugging.
* Changed pathfinder samples argument to num_draws.
* feat(pathfinder): add PyMC-based Pathfinder VI implementation. Add a new PyMC-based implementation of Pathfinder VI that uses PyTensor operations, providing support for both PyMC and BlackJAX backends in fit_pathfinder.
* Multipath Pathfinder VI implementation in pymc-experimental:
  - Implemented support for running multiple Pathfinder instances in parallel.
  - Implemented a function for Pareto Smoothed Importance Resampling (PSIR).
  - Moved the relevant pathfinder files into their own directory.
  - Updated tests to reflect changes in the Pathfinder implementation and added tests for the new functionality.
* Added type hints and epsilon parameter to fit_pathfinder.
* Removed initial point values (l=0) to reduce iterations; simplified related functions.
* Added placeholder/reminder to remove the JAX dependency when converting trace data to InferenceData.
* Sync updates with draft PR pymc-devs#386.
  - Added pytensor.function for bfgs_sample.
* Reduced the size of the compute graph with pathfinder_body_fn. Summary of changes:
  - Removed multiprocessing code in favour of reusing the compiled function for each path, which takes only random_seed as an argument per path.
  - Compute graph significantly smaller by using a pure PyTensor op and symbolic variables.
  - Added LBFGSOp to compile with pytensor.function.
  - Cleaned up code using pytensor variables.
* Added TODO comments for implementing the Taylor approximation methods; corrected the dimensions in comments for matrices Q and R; improved numerical stability in the calculation.
* fix: correct posterior approximations in Pathfinder VI. Fixed incorrect and inconsistent posterior approximations in the Pathfinder VI algorithm by:
  1. Adding missing parentheses in the phi calculation to ensure the proper order of operations in matrix multiplications.
  2. Changing the sign in the mu calculation from 'x +' to 'x -' to match Stan's implementation (which differs from the original paper).
  The resulting changes make the posterior approximations more reliable.
* feat: Add dense BFGS sampling for Pathfinder VI. Implements both sparse and dense BFGS sampling approaches for Pathfinder VI:
  - Adds bfgs_sample_dense for cases where 2 * maxcor >= num_params.
  - Moved the existing computations to bfgs_sample_sparse, making the sparse use cases more explicit.
  Other changes:
  - Sets default maxcor=5 instead of dynamic sizing based on the number of parameters.
  Dense approximations are recommended when the target distribution has higher dependencies among the parameters.
* feat: improve Pathfinder performance and compatibility.
  Bigger changes:
  - Made pmx.fit compatible with method='pathfinder'.
  - Removed the JAX dependency when inference_backend='pymc' to support Windows users.
  - Improved runtime performance by setting trust_input=True for compiled functions.
  Minor changes:
  - Changed default num_paths from 1 to 4 for stable and reliable approximations.
  - Changed LBFGS code to use dataclasses.
  - Updated tests to handle both PyMC and BlackJAX backends.
* minor: improve error handling in Pathfinder VI.
  - Added LBFGSInitFailed exception for failed LBFGS initialisation.
  - Skip failed paths in multipath_pathfinder and track the number of failures.
  - Handle NaN values from Cholesky decomposition in bfgs_sample.
  - Added checks for numerical stability in matrix operations.
  Slight performance improvements:
  - Set allow_gc=False in scan ops.
  - Use FAST_RUN mode consistently.
* Progress bar and other minor changes.
  Major:
  - Added progress bar support.
  Minor:
  - Added an exception for non-finite log prob values.
  - Removed .
  - Allowed the maxcor argument to be None, dynamically set based on the number of model parameters.
  - Improved logging to inform users about failed paths and LBFGS initialisation.
* Set maxcor to max(5, floor(N / 1.9)); maxcor=1 will cause an error.
* Refactor Pathfinder VI: default to PSIS, add concurrency, and improve computational performance.
  - Significantly improved computational efficiency by combining 3 computational graphs into 1 larger compile; removed non-shared inputs for significant performance gains.
  - Set the default importance sampling method to 'psis' for more stable posterior results, avoiding the local peaks seen with 'psir'.
  - Introduced concurrency options ('thread' and 'process') for multithreading and multiprocessing. Defaults to no concurrency, as there hasn't been much reduction in compute time.
  - Adjusted defaults from 8 to 4 and from 1.0 to 2.0, and maxcor to max(3*log(N), 5). This default setting lessens computational time and the degree by which the posterior variance is underestimated.
* Improvements to importance sampling and InferenceData shape.
  - Handle the different importance sampling methods for reshaping and adjusting log densities.
  - Modified to return InferenceData with a chain dim of size num_paths when
* Display summary of results, improve error handling, general improvements.
  Changes:
  - Added a rich table summary display for results.
  - Added PathStatus and LBFGSStatus for error handling, status tracking and displaying results.
  - Changed the importance_sampling return type to ImportanceSamplingResult.
  - Changed the multipath_pathfinder return type to MultiPathfinderResult.
  - Added dataclass containers for results (ImportanceSamplingResult, PathfinderResult, MultiPathfinderResult).
  - Refactored LBFGS by removing PyTensor Op classes in favor of pure functions.
  - Added timing and configuration tracking.
  - Improved concurrency with better error handling.
  - Improved docstrings and type hints.
  - Simplified logp and gradient computation by combining them into a single function.
  - Added a compile_kwargs parameter for pytensor compilation options.
* Move pathfinder module to pymc_extras.
  - Moved the pathfinder module from pymc_experimental to pymc_extras.
  - Updated the directory structure to match the upstream repository.
* Improve pathfinder error handling and type hints.
  - Added proper type hints throughout the pathfinder module.
  - Improved error handling in concurrent execution paths.
  - Better handling of the case where all paths fail, by displaying results before the assertion.
  - Changed Australian English spelling to US.
  - Updated compile_pymc usage to handle a deprecation warning.
  - Added tests for concurrent execution and seed reproducibility.
  - Cleaned up imports and removed redundant code.
  - Improved docstrings and error messages.
* fix: Use typing_extensions.Self for Python 3.10 compatibility.
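The maxcor default mentioned above, max(3*log(N), 5), can be sketched in plain NumPy. Here `N` stands for the number of free model parameters; the exact rounding used by the committed code is an assumption on my part:

```python
import numpy as np

def default_maxcor(num_params: int) -> int:
    """Sketch of the default LBFGS history size described in the commit
    message: max(3 * log(N), 5). The use of ceil here is an assumption;
    the committed implementation may round differently."""
    return max(int(np.ceil(3 * np.log(num_params))), 5)

# small models fall back to the floor of 5
print(default_maxcor(2))    # -> 5
# larger models scale logarithmically with the parameter count
print(default_maxcor(100))  # -> 14
```

The floor of 5 reflects the note above that very small maxcor values (e.g. 1) cause errors in the LBFGS history.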
1 parent 96f7a2e commit eb1183a

File tree

7 files changed: +2213 −145 lines changed


pymc_extras/inference/fit.py (−4 lines)

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from importlib.util import find_spec


 def fit(method, **kwargs):
@@ -31,9 +30,6 @@ def fit(method, **kwargs):
     arviz.InferenceData
     """
     if method == "pathfinder":
-        if find_spec("blackjax") is None:
-            raise RuntimeError("Need BlackJAX to use `pathfinder`")
-
         from pymc_extras.inference.pathfinder import fit_pathfinder

         return fit_pathfinder(**kwargs)
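The diff above drops the eager BlackJAX check in favor of importing the backend lazily inside `fit`, so the optional dependency is only needed when the method is actually requested. A generic sketch of that dispatch pattern (with a stdlib module as a stand-in for the real pathfinder import, so this is illustrative only):

```python
def fit(method: str, **kwargs):
    """Minimal sketch of lazy-import dispatch: the backend module is
    imported only once the matching method is requested, so this module
    stays importable even when the optional backend is not installed."""
    if method == "pathfinder":
        # The real code imports fit_pathfinder from
        # pymc_extras.inference.pathfinder here; json is a stand-in.
        import json
        return json.dumps(kwargs)
    raise ValueError(f"method '{method}' not supported")

print(fit("pathfinder", num_draws=10))  # -> {"num_draws": 10}
```

A wrong method name still fails fast with a `ValueError`, while the import cost and dependency requirement are deferred to first use.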

pymc_extras/inference/pathfinder.py (−134 lines)

This file was deleted.
New file (+3 lines)

@@ -0,0 +1,3 @@
+from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
+
+__all__ = ["fit_pathfinder"]
New file (+139 lines)

@@ -0,0 +1,139 @@
+import logging
+import warnings as _warnings
+
+from dataclasses import dataclass, field
+from typing import Literal
+
+import arviz as az
+import numpy as np
+
+from numpy.typing import NDArray
+from scipy.special import logsumexp
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class ImportanceSamplingResult:
+    """container for importance sampling results"""
+
+    samples: NDArray
+    pareto_k: float | None = None
+    warnings: list[str] = field(default_factory=list)
+    method: str = "none"
+
+
+def importance_sampling(
+    samples: NDArray,
+    logP: NDArray,
+    logQ: NDArray,
+    num_draws: int,
+    method: Literal["psis", "psir", "identity", "none"] | None,
+    random_seed: int | None = None,
+) -> ImportanceSamplingResult:
+    """Pareto Smoothed Importance Resampling (PSIR)
+
+    This implements the Pareto Smoothed Importance Resampling (PSIR) method, as described in
+    Algorithm 5 of Zhang et al. (2022). PSIR follows a similar approach to the Algorithm 1
+    PSIS diagnostic from Yao et al. (2018). However, before computing the importance ratio
+    r_s, logP and logQ are adjusted to account for the number of multiple estimators (or
+    paths). The process involves resampling from the original sample with replacement, with
+    probabilities proportional to the importance weights computed by PSIS.
+
+    Parameters
+    ----------
+    samples : NDArray
+        samples from the proposal distribution, shape (L, M, N)
+    logP : NDArray
+        log probability values of the target distribution, shape (L, M)
+    logQ : NDArray
+        log probability values of the proposal distribution, shape (L, M)
+    num_draws : int
+        number of draws to return, where num_draws <= samples.shape[0]
+    method : str, optional
+        importance sampling method to use. Options are "psis" (default), "psir", "identity",
+        "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for
+        more stable results than Pareto Smoothed Importance Resampling (psir). "identity"
+        applies the log importance weights directly without resampling. "none" applies no
+        importance sampling weights and returns the samples as is, of size
+        num_draws_per_path * num_paths.
+    random_seed : int | None
+
+    Returns
+    -------
+    ImportanceSamplingResult
+        importance sampled draws and other info based on the specified method
+
+    Future work!
+    ------------
+    - Implement the 3 sampling approaches and 5 weighting functions from Elvira et al. (2019)
+    - Implement Algorithm 2 VSBC marginal diagnostics from Yao et al. (2018)
+    - Incorporate these various diagnostics, sampling approaches and weighting functions
+      into VI algorithms.
+
+    References
+    ----------
+    Elvira, V., Martino, L., Luengo, D., & Bugallo, M. F. (2019). Generalized Multiple
+    Importance Sampling. Statistical Science, 34(1), 129-155.
+    https://doi.org/10.1214/18-STS668
+
+    Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but Did It Work?: Evaluating
+    Variational Inference. arXiv:1802.02538 [Stat]. http://arxiv.org/abs/1802.02538
+
+    Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel
+    quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
+    """
+
+    warnings = []
+    num_paths, _, N = samples.shape
+
+    if method == "none":
+        warnings.append(
+            "Importance sampling is disabled. The samples are returned as is, which may "
+            "include samples from failed paths with non-finite logP or logQ values. It is "
+            "recommended to use importance_sampling='psis' for better stability."
+        )
+        return ImportanceSamplingResult(samples=samples, warnings=warnings)
+    else:
+        samples = samples.reshape(-1, N)
+        logP = logP.ravel()
+        logQ = logQ.ravel()
+
+        # adjust log densities
+        log_I = np.log(num_paths)
+        logP -= log_I
+        logQ -= log_I
+        logiw = logP - logQ
+
+        with _warnings.catch_warnings():
+            _warnings.filterwarnings(
+                "ignore", category=RuntimeWarning, message="overflow encountered in exp"
+            )
+            if method == "psis":
+                replace = False
+                logiw, pareto_k = az.psislw(logiw)
+            elif method == "psir":
+                replace = True
+                logiw, pareto_k = az.psislw(logiw)
+            elif method == "identity":
+                replace = False
+                pareto_k = None
+            else:
+                raise ValueError(f"Invalid importance sampling method: {method}")
+
+        # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to
+        # the NUTS posterior or closer to NUTS than ADVI. Pareto k may not be a good
+        # diagnostic for Pathfinder.
+        # TODO: Find replacement diagnostics for Pathfinder.
+
+        p = np.exp(logiw - logsumexp(logiw))
+        rng = np.random.default_rng(random_seed)
+
+        try:
+            resampled = rng.choice(
+                samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0
+            )
+            return ImportanceSamplingResult(
+                samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
+            )
+        except ValueError as e1:
+            if "Fewer non-zero entries in p than size" in str(e1):
+                num_nonzero = np.where(np.nonzero(p)[0], 1, 0).sum()
+                warnings.append(
+                    f"Not enough valid samples: {num_nonzero} available out of {num_draws} "
+                    "requested. Switching to psir importance sampling."
+                )
+                try:
+                    resampled = rng.choice(
+                        samples, size=num_draws, replace=True, p=p, shuffle=False, axis=0
+                    )
+                    return ImportanceSamplingResult(
+                        samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
+                    )
+                except ValueError as e2:
+                    logger.error(
+                        "Importance sampling failed even with psir importance sampling. "
+                        "This might indicate invalid probability weights or insufficient "
+                        "valid samples."
+                    )
+                    raise ValueError(
+                        "Importance sampling failed for both with and without replacement"
+                    ) from e2
+            raise
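The core steps of `importance_sampling` above (flatten the per-path draws into one pool, subtract log(num_paths) from both log densities, turn the log importance weights into probabilities, and resample) can be exercised with a NumPy-only sketch. This follows the "identity"-style branch so `az.psislw` is not required; the array values are synthetic stand-ins, not output of a real model:

```python
import numpy as np

rng = np.random.default_rng(42)
L, M, N = 4, 100, 3                      # paths, draws per path, parameters
samples = rng.normal(size=(L, M, N))
logP = rng.normal(size=(L, M))           # stand-in target log densities
logQ = rng.normal(size=(L, M))           # stand-in proposal log densities

# flatten the paths into one pool, as in the non-"none" branch above
flat = samples.reshape(-1, N)
logiw = (logP.ravel() - np.log(L)) - (logQ.ravel() - np.log(L))
# note: the two log(L) adjustments cancel in the ratio itself; they matter
# for the absolute scale of logP/logQ, which is why the function applies
# them to each density separately

# normalize log weights to probabilities (stable softmax, like
# exp(logiw - logsumexp(logiw)) in the function body)
p = np.exp(logiw - logiw.max())
p /= p.sum()

# resample without replacement, mirroring replace=False for "identity"/"psis"
draws = rng.choice(flat, size=50, replace=False, p=p, shuffle=False, axis=0)
print(draws.shape)  # -> (50, 3)
```

With heavier-tailed weights, `replace=False` can fail with "Fewer non-zero entries in p than size", which is exactly the `ValueError` the function catches before retrying with replacement.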
