|
| 1 | +import sys |
| 2 | + |
| 3 | +from pymc.model.core import Model |
| 4 | +from pymc.sampling.mcmc import sample |
| 5 | +from pytensor.graph.rewriting.basic import GraphRewriter |
| 6 | + |
| 7 | +from pymc_experimental.sampling.optimizations.optimize import ( |
| 8 | + TAGS_TYPE, |
| 9 | + optimize_model_for_mcmc_sampling, |
| 10 | +) |
| 11 | + |
| 12 | + |
| 13 | +def opt_sample( |
| 14 | + *args, |
| 15 | + model: Model | None = None, |
| 16 | + include: TAGS_TYPE = ("default",), |
| 17 | + exclude: TAGS_TYPE = None, |
| 18 | + rewriter: GraphRewriter | None = None, |
| 19 | + verbose: bool = False, |
| 20 | + **kwargs, |
| 21 | +): |
| 22 | + """Sample from a model after applying optimizations. |
| 23 | +
|
| 24 | + Parameters |
| 25 | + ---------- |
| 26 | + model : Model, optinoal |
| 27 | + The model to sample from. If None, use the model associated with the context. |
| 28 | + include : TAGS_TYPE |
| 29 | + The tags to include in the optimizations. Ignored if `rewriter` is not None. |
| 30 | + exclude : TAGS_TYPE |
| 31 | + The tags to exclude from the optimizations. Ignored if `rewriter` is not None. |
| 32 | + rewriter : RewriteDatabaseQuery (optional) |
| 33 | + The rewriter to use. If None, use the default rewriter with the given `include` and `exclude` tags. |
| 34 | + verbose : bool, default=False |
| 35 | + Print information about the optimizations applied. |
| 36 | + *args, **kwargs: |
| 37 | + Passed to `pm.sample` |
| 38 | +
|
| 39 | + Returns |
| 40 | + ------- |
| 41 | + sample_output: |
| 42 | + The output of `pm.sample` |
| 43 | +
|
| 44 | + Examples |
| 45 | + -------- |
| 46 | + .. code:: python |
| 47 | + import pymc as pm |
| 48 | + import pymc_experimental as pmx |
| 49 | +
|
| 50 | + with pm.Model() as m: |
| 51 | + p = pm.Beta("p", 1, 1, shape=(1000,)) |
| 52 | + y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250) |
| 53 | +
|
| 54 | + idata = pmx.opt_sample(verbose=True) |
| 55 | + """ |
| 56 | + if kwargs.get("step", None) is not None: |
| 57 | + raise ValueError( |
| 58 | + "The `step` argument is not supported in `opt_sample`, as custom steps would refer to the original model.\n" |
| 59 | + "You can manually transform the model with `pymc_experimental.sampling.optimizations.optimize_model_for_mcmc_sampling` " |
| 60 | + "and then define the custom steps and forward them to `pymc.sample`." |
| 61 | + ) |
| 62 | + |
| 63 | + opt_model, rewrite_counters = optimize_model_for_mcmc_sampling( |
| 64 | + model, include=include, exclude=exclude, rewriter=rewriter |
| 65 | + ) |
| 66 | + |
| 67 | + if verbose: |
| 68 | + applied_opt = False |
| 69 | + for rewrite_counter in rewrite_counters: |
| 70 | + for rewrite, counts in rewrite_counter.items(): |
| 71 | + applied_opt = True |
| 72 | + print(f"Applied optimization: {rewrite} {counts}x", file=sys.stdout) |
| 73 | + if not applied_opt: |
| 74 | + print("No optimizations applied", file=sys.stdout) |
| 75 | + |
| 76 | + return sample(*args, model=opt_model, **kwargs) |
0 commit comments