Skip to content

Commit cb3d501

Browse files
committed
Add opt_sample
1 parent 5055262 commit cb3d501

File tree

7 files changed

+135
-0
lines changed

7 files changed

+135
-0
lines changed

docs/api_reference.rst

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ methods in the current release of PyMC experimental.
1515
MarginalModel
1616
marginalize
1717
model_builder.ModelBuilder
18+
opt_sample
1819

1920
Inference
2021
=========

pymc_experimental/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pymc_experimental.inference.fit import fit
1919
from pymc_experimental.model.marginal.marginal_model import MarginalModel, marginalize
2020
from pymc_experimental.model.model_api import as_model
21+
from pymc_experimental.sampling.mcmc import opt_sample
2122
from pymc_experimental.version import __version__
2223

2324
_log = logging.getLogger("pmx")

pymc_experimental/sampling/__init__.py

Whitespace-only changes.

pymc_experimental/sampling/mcmc.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import sys
2+
3+
from pymc.model.core import Model
4+
from pymc.sampling.mcmc import sample
5+
from pytensor.graph.rewriting.basic import GraphRewriter
6+
7+
from pymc_experimental.sampling.optimizations.optimize import (
8+
TAGS_TYPE,
9+
optimize_model_for_mcmc_sampling,
10+
)
11+
12+
13+
def opt_sample(
14+
*args,
15+
model: Model | None = None,
16+
include: TAGS_TYPE = ("default",),
17+
exclude: TAGS_TYPE = None,
18+
rewriter: GraphRewriter | None = None,
19+
verbose: bool = False,
20+
**kwargs,
21+
):
22+
"""Sample from a model after applying optimizations.
23+
24+
Parameters
25+
----------
26+
model : Model, optinoal
27+
The model to sample from. If None, use the model associated with the context.
28+
include : TAGS_TYPE
29+
The tags to include in the optimizations. Ignored if `rewriter` is not None.
30+
exclude : TAGS_TYPE
31+
The tags to exclude from the optimizations. Ignored if `rewriter` is not None.
32+
rewriter : RewriteDatabaseQuery (optional)
33+
The rewriter to use. If None, use the default rewriter with the given `include` and `exclude` tags.
34+
verbose : bool, default=False
35+
Print information about the optimizations applied.
36+
*args, **kwargs:
37+
Passed to `pm.sample`
38+
39+
Returns
40+
-------
41+
sample_output:
42+
The output of `pm.sample`
43+
44+
Examples
45+
--------
46+
.. code:: python
47+
import pymc as pm
48+
import pymc_experimental as pmx
49+
50+
with pm.Model() as m:
51+
p = pm.Beta("p", 1, 1, shape=(1000,))
52+
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
53+
54+
idata = pmx.opt_sample(verbose=True)
55+
"""
56+
if kwargs.get("step", None) is not None:
57+
raise ValueError(
58+
"The `step` argument is not supported in `opt_sample`, as custom steps would refer to the original model.\n"
59+
"You can manually transform the model with `pymc_experimental.sampling.optimizations.optimize_model_for_mcmc_sampling` "
60+
"and then define the custom steps and forward them to `pymc.sample`."
61+
)
62+
63+
opt_model, rewrite_counters = optimize_model_for_mcmc_sampling(
64+
model, include=include, exclude=exclude, rewriter=rewriter
65+
)
66+
67+
if verbose:
68+
applied_opt = False
69+
for rewrite_counter in rewrite_counters:
70+
for rewrite, counts in rewrite_counter.items():
71+
applied_opt = True
72+
print(f"Applied optimization: {rewrite} {counts}x", file=sys.stdout)
73+
if not applied_opt:
74+
print("No optimizations applied", file=sys.stdout)
75+
76+
return sample(*args, model=opt_model, **kwargs)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from collections import Counter
2+
from collections.abc import Sequence
3+
from typing import TypeAlias
4+
5+
from pymc.model.core import Model, modelcontext
6+
from pymc.model.fgraph import fgraph_from_model, model_from_fgraph
7+
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery
8+
9+
posterior_optimization_db = EquilibriumDB()
10+
posterior_optimization_db.failure_callback = None # Raise an error if an optimization fails
11+
posterior_optimization_db.name = "posterior_optimization_db"
12+
13+
TAGS_TYPE: TypeAlias = str | Sequence[str] | None
14+
15+
16+
def optimize_model_for_mcmc_sampling(
17+
model: Model,
18+
include: TAGS_TYPE = ("default",),
19+
exclude: TAGS_TYPE = None,
20+
rewriter=None,
21+
) -> tuple[Model, Sequence[Counter]]:
22+
if isinstance(include, str):
23+
include = (include,)
24+
if isinstance(exclude, str):
25+
exclude = (exclude,)
26+
27+
model = modelcontext(model)
28+
fgraph, _ = fgraph_from_model(model)
29+
30+
if rewriter is None:
31+
rewriter = posterior_optimization_db.query(
32+
RewriteDatabaseQuery(include=include, exclude=exclude)
33+
)
34+
_, _, rewrite_counters, *_ = rewriter.rewrite(fgraph)
35+
36+
opt_model = model_from_fgraph(fgraph, mutate_fgraph=True)
37+
return opt_model, rewrite_counters

tests/sampling/mcmc/__init__.py

Whitespace-only changes.

tests/sampling/mcmc/test_mcmc.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import pytest
2+
3+
from pymc.distributions import Beta, Binomial, InverseGamma
4+
from pymc.model.core import Model
5+
from pymc.step_methods import Slice
6+
7+
from pymc_experimental import opt_sample
8+
9+
10+
def test_custom_step_raises():
11+
with Model() as m:
12+
a = InverseGamma("a", 1, 1)
13+
b = InverseGamma("b", 1, 1)
14+
p = Beta("p", a, b)
15+
y = Binomial("y", n=100, p=p, observed=99)
16+
17+
with pytest.raises(
18+
ValueError, match="The `step` argument is not supported in `opt_sample`"
19+
):
20+
opt_sample(step=Slice([a, b]))

0 commit comments

Comments
 (0)