Skip to content

Commit 3416258

Browse files
committed
Added initvals to parameters, constants and observations to returnvalue for pathfinder and cleaned relevant docs a bit
1 parent 4d65ea0 commit 3416258

File tree

7 files changed

+52
-17
lines changed

7 files changed

+52
-17
lines changed

docs/api_reference.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ Inference
2323
.. autosummary::
2424
:toctree: generated/
2525

26+
find_MAP
2627
fit
28+
fit_laplace
29+
fit_pathfinder
2730

2831

2932
Distributions

pymc_extras/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515

1616
from pymc_extras import gp, statespace, utils
1717
from pymc_extras.distributions import *
18-
from pymc_extras.inference.find_map import find_MAP
19-
from pymc_extras.inference.fit import fit
20-
from pymc_extras.inference.laplace import fit_laplace
18+
from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
2119
from pymc_extras.model.marginal.marginal_model import (
2220
MarginalModel,
2321
marginalize,

pymc_extras/inference/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
15+
from pymc_extras.inference.find_map import find_MAP
1616
from pymc_extras.inference.fit import fit
17+
from pymc_extras.inference.laplace import fit_laplace
18+
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
1719

18-
__all__ = ["fit"]
20+
__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]

pymc_extras/inference/fit.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,20 @@
1616

1717
def fit(method: str, **kwargs) -> az.InferenceData:
1818
"""
19-
Fit a model with an inference algorithm
19+
Fit a model with an inference algorithm.
20+
See :func:`fit_pathfinder` and :func:`fit_laplace` for more details.
2021
2122
Parameters
2223
----------
2324
method : str
2425
Which inference method to run.
2526
Supported: pathfinder or laplace
2627
27-
kwargs are passed on.
28+
kwargs: keyword arguments are passed on to the inference method.
2829
2930
Returns
3031
-------
31-
arviz.InferenceData
32+
:class:`~arviz.InferenceData`
3233
"""
3334
if method == "pathfinder":
3435
from pymc_extras.inference.pathfinder import fit_pathfinder

pymc_extras/inference/laplace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def fit_laplace(
509509
510510
Returns
511511
-------
512-
idata: az.InferenceData
512+
:class:`~arviz.InferenceData`
513513
An InferenceData object containing the approximated posterior samples.
514514
515515
Examples

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
1516
import collections
1617
import logging
1718
import time
@@ -24,7 +25,6 @@
2425

2526
import arviz as az
2627
import filelock
27-
import jax
2828
import numpy as np
2929
import pymc as pm
3030
import pytensor
@@ -64,6 +64,7 @@
6464
# TODO: change to typing.Self after Python versions greater than 3.10
6565
from typing_extensions import Self
6666

67+
from pymc_extras.inference.laplace import add_data_to_inferencedata
6768
from pymc_extras.inference.pathfinder.importance_sampling import (
6869
importance_sampling as _importance_sampling,
6970
)
@@ -218,6 +219,8 @@ def convert_flat_trace_to_idata(
218219
result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]
219220

220221
elif inference_backend == "blackjax":
222+
import jax
223+
221224
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
222225
result = jax.vmap(jax.vmap(jax_fn))(
223226
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
@@ -1627,6 +1630,7 @@ def fit_pathfinder(
16271630
inference_backend: Literal["pymc", "blackjax"] = "pymc",
16281631
pathfinder_kwargs: dict = {},
16291632
compile_kwargs: dict = {},
1633+
initvals: dict | None = None,
16301634
) -> az.InferenceData:
16311635
"""
16321636
Fit the Pathfinder Variational Inference algorithm.
@@ -1662,12 +1666,12 @@ def fit_pathfinder(
16621666
importance_sampling : str, None, optional
16631667
Method to apply sampling based on log importance weights (logP - logQ).
16641668
Options are:
1665-
"psis" : Pareto Smoothed Importance Sampling (default)
1666-
Recommended for more stable results.
1667-
"psir" : Pareto Smoothed Importance Resampling
1668-
Less stable than PSIS.
1669-
"identity" : Applies log importance weights directly without resampling.
1670-
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
1669+
1670+
- "psis" : Pareto Smoothed Importance Sampling (default). Usually most stable.
1671+
- "psir" : Pareto Smoothed Importance Resampling. Less stable than PSIS.
1672+
- "identity" : Applies log importance weights directly without resampling.
1673+
- None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
1674+
16711675
progressbar : bool, optional
16721676
Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
16731677
random_seed : RandomSeed, optional
@@ -1682,10 +1686,13 @@ def fit_pathfinder(
16821686
Additional keyword arguments for the Pathfinder algorithm.
16831687
compile_kwargs
16841688
Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
1689+
initvals: dict | None = None
1690+
Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted.
1691+
If None, the model's default initial values are used.
16851692
16861693
Returns
16871694
-------
1688-
arviz.InferenceData
1695+
:class:`~arviz.InferenceData`
16891696
The inference data containing the results of the Pathfinder algorithm.
16901697
16911698
References
@@ -1695,6 +1702,14 @@ def fit_pathfinder(
16951702

16961703
model = modelcontext(model)
16971704

1705+
if initvals is not None:
1706+
model = pm.model.fgraph.clone_model(model) # Create a clone of the model
1707+
for (
1708+
rv_name,
1709+
ivals,
1710+
) in initvals.items(): # Set the initial values for the variables in the clone
1711+
model.set_initval(model.named_vars[rv_name], ivals)
1712+
16981713
valid_importance_sampling = {"psis", "psir", "identity", None}
16991714

17001715
if importance_sampling is not None:
@@ -1734,6 +1749,7 @@ def fit_pathfinder(
17341749
pathfinder_samples = mp_result.samples
17351750
elif inference_backend == "blackjax":
17361751
import blackjax
1752+
import jax
17371753

17381754
if version.parse(blackjax.__version__).major < 1:
17391755
raise ImportError("fit_pathfinder requires blackjax 1.0 or above")
@@ -1772,4 +1788,7 @@ def fit_pathfinder(
17721788
model=model,
17731789
importance_sampling=importance_sampling,
17741790
)
1791+
1792+
idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
1793+
17751794
return idata

tests/test_pathfinder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,15 @@ def test_pathfinder_importance_sampling(importance_sampling):
200200
assert idata.posterior["mu"].shape == (1, num_draws)
201201
assert idata.posterior["tau"].shape == (1, num_draws)
202202
assert idata.posterior["theta"].shape == (1, num_draws, 8)
203+
204+
205+
def test_pathfinder_initvals():
206+
# Run a model with an ordered transform that will fail unless initvals are in place
207+
with pm.Model() as mdl:
208+
pm.Normal("ordered", size=10, transform=pm.distributions.transforms.ordered)
209+
idata = pmx.fit_pathfinder(initvals={"ordered": np.linspace(0, 1, 10)})
210+
211+
# Check that the samples are ordered to make sure transform was applied
212+
assert np.all(
213+
idata.posterior["ordered"][..., 1:].values > idata.posterior["ordered"][..., :-1].values
214+
)

0 commit comments

Comments
 (0)