|
| 1 | +import logging |
| 2 | +import warnings as _warnings |
| 3 | + |
| 4 | +from dataclasses import dataclass, field |
| 5 | +from typing import Literal |
| 6 | + |
| 7 | +import arviz as az |
| 8 | +import numpy as np |
| 9 | + |
| 10 | +from numpy.typing import NDArray |
| 11 | +from scipy.special import logsumexp |
| 12 | + |
| 13 | +logger = logging.getLogger(__name__) |
| 14 | + |
| 15 | + |
| 16 | +@dataclass(frozen=True) |
| 17 | +class ImportanceSamplingResult: |
| 18 | + """container for importance sampling results""" |
| 19 | + |
| 20 | + samples: NDArray |
| 21 | + pareto_k: float | None = None |
| 22 | + warnings: list[str] = field(default_factory=list) |
| 23 | + method: str = "none" |
| 24 | + |
| 25 | + |
def importance_sampling(
    samples: NDArray,
    logP: NDArray,
    logQ: NDArray,
    num_draws: int,
    method: Literal["psis", "psir", "identity", "none"] | None,
    random_seed: int | None = None,
) -> ImportanceSamplingResult:
    """Importance-sample draws pooled from multiple paths (PSIS / PSIR).

    This implements the Pareto Smoothed Importance Resampling (PSIR) method, as
    described in Algorithm 5 of Zhang et al. (2022). PSIR follows a similar
    approach to the Algorithm 1 PSIS diagnostic from Yao et al. (2018). However,
    before computing the importance ratio r_s, logP and logQ are adjusted to
    account for the number of estimators (or paths). Draws are then selected
    from the pooled samples with probabilities proportional to the (optionally
    Pareto-smoothed) importance weights.

    Parameters
    ----------
    samples : NDArray
        samples from proposal distribution, shape (L, M, N)
    logP : NDArray
        log probability values of target distribution, shape (L, M)
    logQ : NDArray
        log probability values of proposal distribution, shape (L, M)
    num_draws : int
        number of draws to return. When drawing without replacement this must
        not exceed the number of pooled draws (L * M) that have a non-zero
        weight; otherwise the function falls back to drawing with replacement
        and records a warning in the result.
    method : str
        importance sampling method to use. Options are "psis", "psir",
        "identity" and "none". "psis" applies Pareto smoothing to the log
        importance weights and draws without replacement; it is recommended in
        many cases for more stable results than "psir", which uses the same
        smoothed weights but draws with replacement. "identity" uses the raw
        log importance weights and draws without replacement. "none" applies no
        importance sampling and returns ``samples`` unchanged, with its
        original (L, M, N) shape.
    random_seed : int | None
        seed passed to ``numpy.random.default_rng``

    Returns
    -------
    ImportanceSamplingResult
        importance sampled draws and other info based on the specified method.
        For every method except "none", ``samples`` has shape (num_draws, N).

    Raises
    ------
    ValueError
        if ``method`` is not one of the four option strings (this includes
        ``None``), or if drawing fails both without and with replacement.

    Future work
    -----------
    - Implement the 3 sampling approaches and 5 weighting functions from Elvira et al. (2019)
    - Implement Algorithm 2 VSBC marginal diagnostics from Yao et al. (2018)
    - Incorporate these various diagnostics, sampling approaches and weighting functions into VI algorithms.

    References
    ----------
    Elvira, V., Martino, L., Luengo, D., & Bugallo, M. F. (2019). Generalized Multiple Importance Sampling. Statistical Science, 34(1), 129-155. https://doi.org/10.1214/18-STS668

    Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but Did It Work?: Evaluating Variational Inference. arXiv:1802.02538 [Stat]. http://arxiv.org/abs/1802.02538

    Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
    """

    warnings = []
    num_paths, _, N = samples.shape

    if method == "none":
        warnings.append(
            "Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
        )
        return ImportanceSamplingResult(samples=samples, warnings=warnings)

    # pool the draws of all paths into a single (L * M, N) array
    samples = samples.reshape(-1, N)

    # adjust log densities for the number of paths. ``ravel`` can return a view
    # of the caller's array, so the adjustment is done out of place to avoid
    # silently modifying the inputs.
    log_I = np.log(num_paths)
    logP = logP.ravel() - log_I
    logQ = logQ.ravel() - log_I
    logiw = logP - logQ

    with _warnings.catch_warnings():
        _warnings.filterwarnings(
            "ignore", category=RuntimeWarning, message="overflow encountered in exp"
        )
        if method in ("psis", "psir"):
            # same smoothed weights; only "psir" draws with replacement
            replace = method == "psir"
            logiw, pareto_k = az.psislw(logiw)
        elif method == "identity":
            replace = False
            pareto_k = None
        else:
            raise ValueError(f"Invalid importance sampling method: {method}")

    # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
    # Pareto k may not be a good diagnostic for Pathfinder.
    # TODO: Find replacement diagnostics for Pathfinder.

    # normalise the weights in log space before exponentiating
    p = np.exp(logiw - logsumexp(logiw))
    rng = np.random.default_rng(random_seed)

    try:
        resampled = rng.choice(
            samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0
        )
    except ValueError as e1:
        # only the "too few non-zero weights" failure is recoverable
        if "Fewer non-zero entries in p than size" not in str(e1):
            raise
        num_nonzero = np.count_nonzero(p)
        warnings.append(
            f"Not enough valid samples: {num_nonzero} available out of {num_draws} requested. Switching to psir importance sampling."
        )
        try:
            resampled = rng.choice(
                samples, size=num_draws, replace=True, p=p, shuffle=False, axis=0
            )
        except ValueError as e2:
            logger.error(
                "Importance sampling failed even with psir importance sampling. "
                "This might indicate invalid probability weights or insufficient valid samples."
            )
            raise ValueError(
                "Importance sampling failed for both with and without replacement"
            ) from e2

    return ImportanceSamplingResult(
        samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
    )
0 commit comments