Skip to content

Commit 53df690

Browse files
authored
Added initvals to parameters, constants and observations to returnvalue for pathfinder and cleaned relevant docs a bit (#447)
1 parent 4d65ea0 commit 53df690

File tree

7 files changed

+56
-18
lines changed

7 files changed

+56
-18
lines changed

docs/api_reference.rst

+3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ Inference
2323
.. autosummary::
2424
:toctree: generated/
2525

26+
find_MAP
2627
fit
28+
fit_laplace
29+
fit_pathfinder
2730

2831

2932
Distributions

pymc_extras/__init__.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515

1616
from pymc_extras import gp, statespace, utils
1717
from pymc_extras.distributions import *
18-
from pymc_extras.inference.find_map import find_MAP
19-
from pymc_extras.inference.fit import fit
20-
from pymc_extras.inference.laplace import fit_laplace
18+
from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
2119
from pymc_extras.model.marginal.marginal_model import (
2220
MarginalModel,
2321
marginalize,

pymc_extras/inference/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
15+
from pymc_extras.inference.find_map import find_MAP
1616
from pymc_extras.inference.fit import fit
17+
from pymc_extras.inference.laplace import fit_laplace
18+
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
1719

18-
__all__ = ["fit"]
20+
__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]

pymc_extras/inference/fit.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,20 @@
1616

1717
def fit(method: str, **kwargs) -> az.InferenceData:
1818
"""
19-
Fit a model with an inference algorithm
19+
Fit a model with an inference algorithm.
20+
See :func:`fit_pathfinder` and :func:`fit_laplace` for more details.
2021
2122
Parameters
2223
----------
2324
method : str
2425
Which inference method to run.
2526
Supported: pathfinder or laplace
2627
27-
kwargs are passed on.
28+
kwargs: keyword arguments are passed on to the inference method.
2829
2930
Returns
3031
-------
31-
arviz.InferenceData
32+
:class:`~arviz.InferenceData`
3233
"""
3334
if method == "pathfinder":
3435
from pymc_extras.inference.pathfinder import fit_pathfinder

pymc_extras/inference/laplace.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def fit_laplace(
509509
510510
Returns
511511
-------
512-
idata: az.InferenceData
512+
:class:`~arviz.InferenceData`
513513
An InferenceData object containing the approximated posterior samples.
514514
515515
Examples

pymc_extras/inference/pathfinder/pathfinder.py

+31-9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
1516
import collections
1617
import logging
1718
import time
@@ -24,7 +25,6 @@
2425

2526
import arviz as az
2627
import filelock
27-
import jax
2828
import numpy as np
2929
import pymc as pm
3030
import pytensor
@@ -43,7 +43,6 @@
4343
find_rng_nodes,
4444
reseed_rngs,
4545
)
46-
from pymc.sampling.jax import get_jaxified_graph
4746
from pymc.util import (
4847
CustomProgress,
4948
RandomSeed,
@@ -64,6 +63,7 @@
6463
# TODO: change to typing.Self after Python versions greater than 3.10
6564
from typing_extensions import Self
6665

66+
from pymc_extras.inference.laplace import add_data_to_inferencedata
6767
from pymc_extras.inference.pathfinder.importance_sampling import (
6868
importance_sampling as _importance_sampling,
6969
)
@@ -99,6 +99,8 @@ def get_jaxified_logp_of_ravel_inputs(model: Model, jacobian: bool = True) -> Ca
9999
A JAX function that computes the log-probability of a PyMC model with ravelled inputs.
100100
"""
101101

102+
from pymc.sampling.jax import get_jaxified_graph
103+
102104
# TODO: JAX: test if we should get jaxified graph of dlogp as well
103105
new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
104106
model.initial_point(), (model.logp(jacobian=jacobian),), model.value_vars, ()
@@ -218,6 +220,10 @@ def convert_flat_trace_to_idata(
218220
result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]
219221

220222
elif inference_backend == "blackjax":
223+
import jax
224+
225+
from pymc.sampling.jax import get_jaxified_graph
226+
221227
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
222228
result = jax.vmap(jax.vmap(jax_fn))(
223229
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
@@ -1627,6 +1633,7 @@ def fit_pathfinder(
16271633
inference_backend: Literal["pymc", "blackjax"] = "pymc",
16281634
pathfinder_kwargs: dict = {},
16291635
compile_kwargs: dict = {},
1636+
initvals: dict | None = None,
16301637
) -> az.InferenceData:
16311638
"""
16321639
Fit the Pathfinder Variational Inference algorithm.
@@ -1662,12 +1669,12 @@ def fit_pathfinder(
16621669
importance_sampling : str, None, optional
16631670
Method to apply sampling based on log importance weights (logP - logQ).
16641671
Options are:
1665-
"psis" : Pareto Smoothed Importance Sampling (default)
1666-
Recommended for more stable results.
1667-
"psir" : Pareto Smoothed Importance Resampling
1668-
Less stable than PSIS.
1669-
"identity" : Applies log importance weights directly without resampling.
1670-
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
1672+
1673+
- "psis" : Pareto Smoothed Importance Sampling (default). Usually most stable.
1674+
- "psir" : Pareto Smoothed Importance Resampling. Less stable than PSIS.
1675+
- "identity" : Applies log importance weights directly without resampling.
1676+
- None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
1677+
16711678
progressbar : bool, optional
16721679
Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
16731680
random_seed : RandomSeed, optional
@@ -1682,10 +1689,13 @@ def fit_pathfinder(
16821689
Additional keyword arguments for the Pathfinder algorithm.
16831690
compile_kwargs
16841691
Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
1692+
initvals: dict | None = None
1693+
Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted.
1694+
If None, the model's default initial values are used.
16851695
16861696
Returns
16871697
-------
1688-
arviz.InferenceData
1698+
:class:`~arviz.InferenceData`
16891699
The inference data containing the results of the Pathfinder algorithm.
16901700
16911701
References
@@ -1695,6 +1705,14 @@ def fit_pathfinder(
16951705

16961706
model = modelcontext(model)
16971707

1708+
if initvals is not None:
1709+
model = pm.model.fgraph.clone_model(model) # Create a clone of the model
1710+
for (
1711+
rv_name,
1712+
ivals,
1713+
) in initvals.items(): # Set the initial values for the variables in the clone
1714+
model.set_initval(model.named_vars[rv_name], ivals)
1715+
16981716
valid_importance_sampling = {"psis", "psir", "identity", None}
16991717

17001718
if importance_sampling is not None:
@@ -1734,6 +1752,7 @@ def fit_pathfinder(
17341752
pathfinder_samples = mp_result.samples
17351753
elif inference_backend == "blackjax":
17361754
import blackjax
1755+
import jax
17371756

17381757
if version.parse(blackjax.__version__).major < 1:
17391758
raise ImportError("fit_pathfinder requires blackjax 1.0 or above")
@@ -1772,4 +1791,7 @@ def fit_pathfinder(
17721791
model=model,
17731792
importance_sampling=importance_sampling,
17741793
)
1794+
1795+
idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
1796+
17751797
return idata

tests/test_pathfinder.py

+12
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,15 @@ def test_pathfinder_importance_sampling(importance_sampling):
200200
assert idata.posterior["mu"].shape == (1, num_draws)
201201
assert idata.posterior["tau"].shape == (1, num_draws)
202202
assert idata.posterior["theta"].shape == (1, num_draws, 8)
203+
204+
205+
def test_pathfinder_initvals():
206+
# Run a model with an ordered transform that will fail unless initvals are in place
207+
with pm.Model() as mdl:
208+
pm.Normal("ordered", size=10, transform=pm.distributions.transforms.ordered)
209+
idata = pmx.fit_pathfinder(initvals={"ordered": np.linspace(0, 1, 10)})
210+
211+
# Check that the samples are ordered to make sure transform was applied
212+
assert np.all(
213+
idata.posterior["ordered"][..., 1:].values > idata.posterior["ordered"][..., :-1].values
214+
)

0 commit comments

Comments
 (0)