Skip to content

Commit 3c19a5e

Browse files
committed
Add opt_sample
1 parent 5055262 commit 3c19a5e

File tree

4 files changed

+67
-0
lines changed

4 files changed

+67
-0
lines changed

docs/api_reference.rst

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ methods in the current release of PyMC experimental.
1515
MarginalModel
1616
marginalize
1717
model_builder.ModelBuilder
18+
opt_sample
1819

1920
Inference
2021
=========

pymc_experimental/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pymc_experimental.inference.fit import fit
1919
from pymc_experimental.model.marginal.marginal_model import MarginalModel, marginalize
2020
from pymc_experimental.model.model_api import as_model
21+
from pymc_experimental.sampling.mcmc import opt_sample
2122
from pymc_experimental.version import __version__
2223

2324
_log = logging.getLogger("pmx")

pymc_experimental/sampling/__init__.py

Whitespace-only changes.

pymc_experimental/sampling/mcmc.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import sys
2+
3+
from pymc.model.core import Model, modelcontext
4+
from pymc.model.fgraph import fgraph_from_model, model_from_fgraph
5+
from pymc.sampling.mcmc import sample
6+
from pytensor.graph.rewriting.basic import GraphRewriter
7+
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery
8+
9+
posterior_optimization_db = EquilibriumDB()
10+
posterior_optimization_db.failure_callback = None # Raise an error if an optimization fails
11+
posterior_optimization_db.name = "posterior_optimization_db"
12+
13+
14+
def opt_sample(
15+
*args,
16+
model: Model | None = None,
17+
rewriter: GraphRewriter | None = None,
18+
verbose: bool = False,
19+
**kwargs,
20+
):
21+
"""Sample from a model after applying optimizations.
22+
23+
Parameters
24+
----------
25+
model : Model (optional)
26+
The model to sample from. If None, use the model associated with the context.
27+
rewriter : RewriteDatabaseQuery (optional)
28+
The rewriter to use. If None, use the default rewriter.
29+
verbose : bool, default=False
30+
Print information about the optimizations applied.
31+
*args, **kwargs:
32+
Passed to `pm.sample`
33+
34+
Returns
35+
-------
36+
sample_output:
37+
The output of `pm.sample`
38+
39+
Examples
40+
--------
41+
.. code:: python
42+
import pymc as pm
43+
import pymc_experimental as pmx
44+
45+
with pm.Model() as m:
46+
p = pm.Beta("p", 1, 1)
47+
y = pm.Binomial("y", n=10, p=p, observed=5)
48+
49+
idata = pmx.opt_sample(verbose=True)
50+
"""
51+
52+
model = modelcontext(model)
53+
fgraph, _ = fgraph_from_model(model)
54+
55+
if rewriter is None:
56+
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=[]))
57+
_, _, rewrite_counters, *_ = rewriter.rewrite(fgraph)
58+
59+
if verbose:
60+
for rewrite_counter in rewrite_counters:
61+
for rewrite, counts in rewrite_counter.items():
62+
print(f"Applied optimization: {rewrite} {counts}x", file=sys.stdout)
63+
64+
new_model = model_from_fgraph(fgraph)
65+
return sample(*args, model=new_model, **kwargs)

0 commit comments

Comments
 (0)