Skip to content

Commit 8de5346

Browse files
ricardoV94jessegrabowski
authored andcommitted
Add Beta-Binomial conjugacy optimization
1 parent cbdb404 commit 8de5346

File tree

8 files changed

+490
-12
lines changed

8 files changed

+490
-12
lines changed

pymc_experimental/sampling/mcmc.py

+38
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,44 @@ def opt_sample(
5252
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
5353
5454
idata = pmx.opt_sample(verbose=True)
55+
56+
# Applied optimization: beta_binomial_conjugacy 1x
57+
# ConjugateRVSampler: [p]
58+
59+
60+
You can control which optimizations are applied using the `include` and `exclude` arguments:
61+
62+
.. code:: python
63+
import pymc as pm
64+
import pymc_experimental as pmx
65+
66+
with pm.Model() as m:
67+
p = pm.Beta("p", 1, 1, shape=(1000,))
68+
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
69+
70+
idata = pmx.opt_sample(exclude="conjugacy", verbose=True)
71+
72+
# No optimizations applied
73+
# NUTS: [p]
74+
75+
.. code:: python
76+
import pymc as pm
77+
import pymc_experimental as pmx
78+
79+
with pm.Model() as m:
80+
a = pm.InverseGamma("a", 1, 1)
81+
b = pm.InverseGamma("b", 1, 1)
82+
p = pm.Beta("p", a, b, shape=(1000,))
83+
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
84+
85+
# By default, the conjugacy of p will not be applied because it depends on other free variables
86+
idata = pmx.opt_sample(include="conjugacy-eager", verbose=True)
87+
88+
# Applied optimization: beta_binomial_conjugacy_eager 1x
89+
# CompoundStep
90+
# >NUTS: [a, b]
91+
# >ConjugateRVSampler: [p]
92+
5593
"""
5694
if kwargs.get("step", None) is not None:
5795
raise ValueError(

pymc_experimental/sampling/optimizations/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# ruff: noqa: F401
22
# Add rewrites to the optimization DBs
3+
import pymc_experimental.sampling.optimizations.conjugacy
34
import pymc_experimental.sampling.optimizations.summary_stats
45

56
from pymc_experimental.sampling.optimizations.optimize import (
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
from collections.abc import Sequence
2+
from functools import partial
3+
4+
from pymc.distributions import Beta, Binomial
5+
from pymc.model.fgraph import ModelFreeRV, ModelValuedVar, model_free_rv
6+
from pymc.pytensorf import collect_default_updates
7+
from pytensor.graph.basic import Variable, ancestors
8+
from pytensor.graph.fg import FunctionGraph, Output
9+
from pytensor.graph.rewriting.basic import node_rewriter
10+
from pytensor.tensor.elemwise import DimShuffle
11+
from pytensor.tensor.subtensor import _sum_grad_over_bcasted_dims as sum_bcasted_dims
12+
13+
from pymc_experimental.sampling.optimizations.conjugate_sampler import (
14+
ConjugateRV,
15+
)
16+
from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db
17+
18+
19+
def register_conjugacy_rewrites_variants(rewrite_fn, tracks=(ModelFreeRV,)):
20+
"""Register a rewrite function and its force variant in the posterior optimization DB."""
21+
name = rewrite_fn.__name__
22+
23+
rewrite_fn_default = partial(rewrite_fn, eager=False)
24+
rewrite_fn_default.__name__ = name
25+
rewrite_default = node_rewriter(tracks=tracks)(rewrite_fn_default)
26+
27+
rewrite_fn_eager = partial(rewrite_fn, eager=True)
28+
rewrite_fn_eager.__name__ = f"{name}_eager"
29+
rewrite_eager = node_rewriter(tracks=tracks)(rewrite_fn_eager)
30+
31+
posterior_optimization_db.register(
32+
rewrite_default.__name__,
33+
rewrite_default,
34+
"default",
35+
"conjugacy",
36+
)
37+
38+
posterior_optimization_db.register(
39+
rewrite_eager.__name__,
40+
rewrite_eager,
41+
"non-default",
42+
"conjugacy-eager",
43+
)
44+
45+
return rewrite_default, rewrite_eager
46+
47+
48+
def has_free_rv_ancestor(vars: Variable | Sequence[Variable]) -> bool:
49+
"""Return True if any of the variables have a model variable as an ancestor."""
50+
if not isinstance(vars, Sequence):
51+
vars = (vars,)
52+
53+
# TODO: It should stop at observed RVs, it doesn't matter if they have a free RV above
54+
# Did not implement due to laziness and it being a rare case
55+
return any(
56+
var.owner is not None and isinstance(var.owner.op, ModelFreeRV) for var in ancestors(vars)
57+
)
58+
59+
60+
def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable:
61+
"""Return the Model dummy var that wraps the RV"""
62+
for client, _ in fgraph.clients[rv]:
63+
if isinstance(client.op, ModelValuedVar):
64+
return client.outputs[0]
65+
66+
67+
def get_dist_params(rv: Variable) -> tuple[Variable]:
68+
return rv.owner.op.dist_params(rv.owner)
69+
70+
71+
def get_size_param(rv: Variable) -> Variable:
72+
return rv.owner.op.size_param(rv.owner)
73+
74+
75+
def rv_used_by(
76+
fgraph: FunctionGraph,
77+
rv: Variable,
78+
used_by_type: type,
79+
used_as_arg_idx: int | Sequence[int],
80+
arg_idx_offset: int = 2, # Ignore the first two arguments (rng and size)
81+
strict: bool = True,
82+
) -> list[Variable]:
83+
"""Return the RVs that use `rv` as an argument in an operation of type `used_by_type`.
84+
85+
RV may be used directly or broadcasted before being used.
86+
87+
Parameters
88+
----------
89+
fgraph : FunctionGraph
90+
The function graph containing the RVs
91+
rv : Variable
92+
The RV to check for uses.
93+
used_by_type : type
94+
The type of operation that may use the RV.
95+
used_as_arg_idx : int | Sequence[int]
96+
The index of the RV in the operation's inputs.
97+
strict : bool, default=True
98+
If True, return no results when the RV is used in an unrecognized way.
99+
100+
"""
101+
if isinstance(used_as_arg_idx, int):
102+
used_as_arg_idx = (used_as_arg_idx,)
103+
used_as_arg_idx = tuple(arg_idx + arg_idx_offset for arg_idx in used_as_arg_idx)
104+
105+
clients = fgraph.clients
106+
used_by: list[Variable] = []
107+
for client, inp_idx in clients[rv]:
108+
if isinstance(client.op, Output):
109+
continue
110+
111+
if isinstance(client.op, used_by_type) and inp_idx in used_as_arg_idx:
112+
# RV is directly used by the RV type
113+
used_by.append(client.default_output())
114+
115+
elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims:
116+
for sub_client, sub_inp_idx in clients[client.outputs[0]]:
117+
if isinstance(sub_client.op, used_by_type) and sub_inp_idx in used_as_arg_idx:
118+
# RV is broadcasted and then used by the RV type
119+
used_by.append(sub_client.default_output())
120+
elif strict:
121+
# Some other unrecognized use, bail out
122+
return []
123+
elif strict:
124+
# Some other unrecognized use, bail out
125+
return []
126+
127+
return used_by
128+
129+
130+
def wrap_rv_and_conjugate_rv(
131+
fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable]
132+
) -> Variable:
133+
"""Wrap the RV and its conjugate posterior RV in a ConjugateRV node.
134+
135+
Also takes care of handling the random number generators used in the conjugate posterior.
136+
"""
137+
rngs, next_rngs = zip(*collect_default_updates(conjugate_rv, inputs=[rv, *inputs]).items())
138+
for rng in rngs:
139+
if rng not in fgraph.inputs:
140+
fgraph.add_input(rng)
141+
conjugate_op = ConjugateRV(inputs=[rv, *inputs, *rngs], outputs=[rv, conjugate_rv, *next_rngs])
142+
return conjugate_op(rv, *inputs, *rngs)[0]
143+
144+
145+
def create_untransformed_free_rv(
146+
fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable]
147+
) -> Variable:
148+
"""Create a model FreeRV without transform."""
149+
transform = None
150+
value = rv.type(name=name)
151+
fgraph.add_input(value)
152+
free_rv = model_free_rv(rv, value, transform, *dims)
153+
free_rv.name = name
154+
return free_rv
155+
156+
157+
def beta_binomial_conjugacy(fgraph: FunctionGraph, node, eager: bool = False):
158+
if not isinstance(node.op, ModelFreeRV):
159+
return None
160+
161+
[beta_free_rv] = node.outputs
162+
beta_rv, _, *beta_dims = node.inputs
163+
164+
if not isinstance(beta_rv.owner.op, Beta):
165+
return None
166+
167+
a, b = get_dist_params(beta_rv)
168+
if not eager and has_free_rv_ancestor([a, b]):
169+
# Don't apply rewrite if a, b depend on other model variables as that will force a Gibbs sampling scheme
170+
return None
171+
172+
p_arg_idx = 1 # Params to the Binomial are (n, p)
173+
binomial_rvs = rv_used_by(fgraph, beta_free_rv, Binomial, p_arg_idx)
174+
175+
if len(binomial_rvs) != 1:
176+
# Question: Can we apply conjugacy when RV is used by more than one binomial?
177+
return None
178+
179+
[binomial_rv] = binomial_rvs
180+
181+
binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv)
182+
if binomial_model_var is None:
183+
return None
184+
185+
# We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv)
186+
n, _ = get_dist_params(binomial_rv)
187+
188+
# Use value of y in new graph to avoid circularity
189+
y = binomial_model_var.owner.inputs[1]
190+
191+
conjugate_a = sum_bcasted_dims(beta_rv, a + y)
192+
conjugate_b = sum_bcasted_dims(beta_rv, b + (n - y))
193+
194+
conjugate_beta_rv = Beta.dist(conjugate_a, conjugate_b, size=get_size_param(beta_rv))
195+
196+
new_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, conjugate_beta_rv, [a, b, n, y])
197+
new_beta_free_rv = create_untransformed_free_rv(
198+
fgraph, new_beta_rv, beta_free_rv.name, beta_dims
199+
)
200+
return [new_beta_free_rv]
201+
202+
203+
beta_binomial_conjugacy_default, beta_binomial_conjugacy_force = (
204+
register_conjugacy_rewrites_variants(beta_binomial_conjugacy)
205+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import numpy as np
2+
3+
from pymc import STEP_METHODS
4+
from pymc.distributions.distribution import _support_point
5+
from pymc.initial_point import PointType
6+
from pymc.logprob.abstract import MeasurableOp, _logprob
7+
from pymc.model.core import modelcontext
8+
from pymc.pytensorf import compile_pymc
9+
from pymc.step_methods.compound import BlockedStep, Competence, StepMethodState
10+
from pymc.util import get_value_vars_from_user_vars
11+
from pytensor import shared
12+
from pytensor.compile.builders import OpFromGraph
13+
from pytensor.link.jax.linker import JAXLinker
14+
from pytensor.tensor.random.type import RandomGeneratorType
15+
16+
from pymc_experimental.utils.ofg import inline_ofg_outputs
17+
18+
19+
class ConjugateRV(OpFromGraph, MeasurableOp):
20+
"""Wrapper for ConjugateRVs, that outputs the original RV and the conjugate posterior expression.
21+
22+
For partial step samplers to work, the logp and initial point correspond to the original RV
23+
while the variable itself is sampled by default by the `ConjugateRVSampler` by evaluating directly the
24+
conjugate posterior expression (i.e., taking forward random draws).
25+
"""
26+
27+
28+
@_logprob.register(ConjugateRV)
29+
def conjugate_rv_logp(op, values, rv, *params, **kwargs):
30+
# Logp is the same as the original RV
31+
return _logprob(rv.owner.op, values, *rv.owner.inputs)
32+
33+
34+
@_support_point.register(ConjugateRV)
35+
def conjugate_rv_support_point(op, conjugate_rv, rv, *params):
36+
# Support point is the same as the original RV
37+
return _support_point(rv.owner.op, rv, *rv.owner.inputs)
38+
39+
40+
class ConjugateRVSampler(BlockedStep):
41+
name = "conjugate_rv_sampler"
42+
_state_class = StepMethodState
43+
44+
def __init__(self, vars, model=None, rng=None, compile_kwargs: dict | None = None, **kwargs):
45+
if len(vars) != 1:
46+
raise ValueError("ConjugateRVSampler can only be assigned to one variable at a time")
47+
48+
model = modelcontext(model)
49+
[value] = get_value_vars_from_user_vars(vars, model=model)
50+
rv = model.values_to_rvs[value]
51+
self.vars = (value,)
52+
self.rv_name = value.name
53+
54+
if model.rvs_to_transforms[rv] is not None:
55+
raise ValueError("Variable assigned to ConjugateRVSampler cannot be transformed")
56+
57+
rv_and_posterior_rv_node = rv.owner
58+
op = rv_and_posterior_rv_node.op
59+
if not isinstance(op, ConjugateRV):
60+
raise ValueError("Variable must be a ConjugateRV")
61+
62+
# Replace RVs in inputs of rv_posterior_rv_node by the corresponding value variables
63+
value_inputs = model.replace_rvs_by_values(
64+
[rv_and_posterior_rv_node.outputs[1]],
65+
)[0].owner.inputs
66+
# Inline the ConjugateRV graph to only compile `posterior_rv`
67+
_, posterior_rv, *_ = inline_ofg_outputs(op, value_inputs)
68+
69+
if compile_kwargs is None:
70+
compile_kwargs = {}
71+
self.posterior_fn = compile_pymc(
72+
model.value_vars,
73+
posterior_rv,
74+
random_seed=rng,
75+
on_unused_input="ignore",
76+
**compile_kwargs,
77+
)
78+
self.posterior_fn.trust_input = True
79+
if isinstance(self.posterior_fn.maker.linker, JAXLinker):
80+
# Reseeding RVs in JAX backend requires a different logic, becuase the SharedVariables
81+
# used internally are not the ones that `function.get_shared()` returns.
82+
raise ValueError("ConjugateRVSampler is not compatible with JAX backend")
83+
84+
def set_rng(self, rng: np.random.Generator):
85+
# Copy the function and replace any shared RNGs
86+
# This is needed so that it can work correctly with multiple traces
87+
# This will be costly if set_rng is called too often!
88+
shared_rngs = [
89+
var
90+
for var in self.posterior_fn.get_shared()
91+
if isinstance(var.type, RandomGeneratorType)
92+
]
93+
n_shared_rngs = len(shared_rngs)
94+
swap = {
95+
old_shared_rng: shared(rng, borrow=True)
96+
for old_shared_rng, rng in zip(shared_rngs, rng.spawn(n_shared_rngs), strict=True)
97+
}
98+
self.posterior_fn = self.posterior_fn.copy(swap=swap)
99+
100+
def step(self, point: PointType) -> tuple[PointType, list]:
101+
new_point = point.copy()
102+
new_point[self.rv_name] = self.posterior_fn(**point)
103+
return new_point, []
104+
105+
@staticmethod
106+
def competence(var, has_grad):
107+
"""BinaryMetropolis is only suitable for Bernoulli and Categorical variables with k=2."""
108+
if isinstance(var.owner.op, ConjugateRV):
109+
return Competence.IDEAL
110+
111+
return Competence.INCOMPATIBLE
112+
113+
114+
# Register the ConjugateRVSampler
115+
STEP_METHODS.append(ConjugateRVSampler)

0 commit comments

Comments
 (0)