Skip to content

Commit 7bdb03e

Browse files
committed
Add Beta-Binomial conjugacy optimization
1 parent 1469915 commit 7bdb03e

File tree

8 files changed

+490
-13
lines changed

8 files changed

+490
-13
lines changed

pymc_experimental/model/marginal/distributions.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pymc.logprob.abstract import MeasurableOp, _logprob
88
from pymc.logprob.basic import conditional_logp, logp
99
from pymc.pytensorf import constant_fold
10-
from pytensor import Variable
1110
from pytensor.compile.builders import OpFromGraph
1211
from pytensor.compile.mode import Mode
1312
from pytensor.graph import Op, vectorize_graph
@@ -17,6 +16,7 @@
1716
from pytensor.tensor import TensorVariable
1817

1918
from pymc_experimental.distributions import DiscreteMarkovChain
19+
from pymc_experimental.utils.ofg import inline_ofg_outputs
2020

2121

2222
class MarginalRV(OpFromGraph, MeasurableOp):
@@ -126,18 +126,6 @@ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> Tens
126126
return logp.transpose(*dims_alignment)
127127

128128

129-
def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
130-
"""Inline the inner graph (outputs) of an OpFromGraph Op.
131-
132-
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
133-
the inner graph.
134-
"""
135-
return clone_replace(
136-
op.inner_outputs,
137-
replace=tuple(zip(op.inner_inputs, inputs)),
138-
)
139-
140-
141129
DUMMY_ZERO = pt.constant(0, name="dummy_zero")
142130

143131

pymc_experimental/sampling/mcmc.py

+38
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,44 @@ def opt_sample(
5252
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
5353
5454
idata = pmx.opt_sample(verbose=True)
55+
56+
# Applied optimization: beta_binomial_conjugacy 1x
57+
# ConjugateRVSampler: [p]
58+
59+
60+
You can control which optimizations are applied using the `include` and `exclude` arguments:
61+
62+
.. code:: python
63+
import pymc as pm
64+
import pymc_experimental as pmx
65+
66+
with pm.Model() as m:
67+
p = pm.Beta("p", 1, 1, shape=(1000,))
68+
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
69+
70+
idata = pmx.opt_sample(exclude="conjugacy", verbose=True)
71+
72+
# No optimizations applied
73+
# NUTS: [p]
74+
75+
.. code:: python
76+
import pymc as pm
77+
import pymc_experimental as pmx
78+
79+
with pm.Model() as m:
80+
a = pm.InverseGamma("a", 1, 1)
81+
b = pm.InverseGamma("b", 1, 1)
82+
p = pm.Beta("p", a, b, shape=(1000,))
83+
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
84+
85+
# By default, the conjugacy of p will not be applied because it depends on other free variables
86+
idata = pmx.opt_sample(include="conjugacy-eager", verbose=True)
87+
88+
# Applied optimization: beta_binomial_conjugacy_eager 1x
89+
# CompoundStep
90+
# >NUTS: [a, b]
91+
# >ConjugateRVSampler: [p]
92+
5593
"""
5694
if kwargs.get("step", None) is not None:
5795
raise ValueError(

pymc_experimental/sampling/optimizations/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# ruff: noqa: F401
22
# Add rewrites to the optimization DBs
3+
import pymc_experimental.sampling.optimizations.conjugacy
34
import pymc_experimental.sampling.optimizations.summary_stats
45

56
from pymc_experimental.sampling.optimizations.optimize import (
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
from collections.abc import Sequence
2+
from functools import partial
3+
4+
from pymc.distributions import Beta, Binomial
5+
from pymc.model.fgraph import ModelFreeRV, ModelValuedVar, model_free_rv
6+
from pymc.pytensorf import collect_default_updates
7+
from pytensor.graph.basic import Variable, ancestors
8+
from pytensor.graph.fg import FunctionGraph, Output
9+
from pytensor.graph.rewriting.basic import node_rewriter
10+
from pytensor.tensor.elemwise import DimShuffle
11+
from pytensor.tensor.subtensor import _sum_grad_over_bcasted_dims as sum_bcasted_dims
12+
13+
from pymc_experimental.sampling.optimizations.conjugate_sampler import (
14+
ConjugateRV,
15+
)
16+
from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db
17+
18+
19+
def register_conjugacy_rewrites_variants(rewrite_fn, tracks=(ModelFreeRV,)):
20+
"""Register a rewrite function and its force variant in the posterior optimization DB."""
21+
name = rewrite_fn.__name__
22+
23+
rewrite_fn_default = partial(rewrite_fn, eager=False)
24+
rewrite_fn_default.__name__ = name
25+
rewrite_default = node_rewriter(tracks=tracks)(rewrite_fn_default)
26+
27+
rewrite_fn_eager = partial(rewrite_fn, eager=True)
28+
rewrite_fn_eager.__name__ = f"{name}_eager"
29+
rewrite_eager = node_rewriter(tracks=tracks)(rewrite_fn_eager)
30+
31+
posterior_optimization_db.register(
32+
rewrite_default.__name__,
33+
rewrite_default,
34+
"default",
35+
"conjugacy",
36+
)
37+
38+
posterior_optimization_db.register(
39+
rewrite_eager.__name__,
40+
rewrite_eager,
41+
"non-default",
42+
"conjugacy-eager",
43+
)
44+
45+
return rewrite_default, rewrite_eager
46+
47+
48+
def has_free_rv_ancestor(vars: Variable | Sequence[Variable]) -> bool:
49+
"""Return True if any of the variables have a model variable as an ancestor."""
50+
if not isinstance(vars, Sequence):
51+
vars = (vars,)
52+
53+
# TODO: It should stop at observed RVs, it doesn't matter if they have a free RV above
54+
# Did not implement due to laziness and it being a rare case
55+
return any(
56+
var.owner is not None and isinstance(var.owner.op, ModelFreeRV) for var in ancestors(vars)
57+
)
58+
59+
60+
def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable:
61+
"""Return the Model dummy var that wraps the RV"""
62+
for client, _ in fgraph.clients[rv]:
63+
if isinstance(client.op, ModelValuedVar):
64+
return client.outputs[0]
65+
66+
67+
def get_dist_params(rv: Variable) -> tuple[Variable]:
68+
return rv.owner.op.dist_params(rv.owner)
69+
70+
71+
def get_size_param(rv: Variable) -> Variable:
72+
return rv.owner.op.size_param(rv.owner)
73+
74+
75+
def rv_used_by(
76+
fgraph: FunctionGraph,
77+
rv: Variable,
78+
used_by_type: type,
79+
used_as_arg_idx: int | Sequence[int],
80+
arg_idx_offset: int = 2, # Ignore the first two arguments (rng and size)
81+
strict: bool = True,
82+
) -> list[Variable]:
83+
"""Return the RVs that use `rv` as an argument in an operation of type `used_by_type`.
84+
85+
RV may be used directly or broadcasted before being used.
86+
87+
Parameters
88+
----------
89+
fgraph : FunctionGraph
90+
The function graph containing the RVs
91+
rv : Variable
92+
The RV to check for uses.
93+
used_by_type : type
94+
The type of operation that may use the RV.
95+
used_as_arg_idx : int | Sequence[int]
96+
The index of the RV in the operation's inputs.
97+
strict : bool, default=True
98+
If True, return no results when the RV is used in an unrecognized way.
99+
100+
"""
101+
if isinstance(used_as_arg_idx, int):
102+
used_as_arg_idx = (used_as_arg_idx,)
103+
used_as_arg_idx = tuple(arg_idx + arg_idx_offset for arg_idx in used_as_arg_idx)
104+
105+
clients = fgraph.clients
106+
used_by: list[Variable] = []
107+
for client, inp_idx in clients[rv]:
108+
if isinstance(client.op, Output):
109+
continue
110+
111+
if isinstance(client.op, used_by_type) and inp_idx in used_as_arg_idx:
112+
# RV is directly used by the RV type
113+
used_by.append(client.default_output())
114+
115+
elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims:
116+
for sub_client, sub_inp_idx in clients[client.outputs[0]]:
117+
if isinstance(sub_client.op, used_by_type) and sub_inp_idx in used_as_arg_idx:
118+
# RV is broadcasted and then used by the RV type
119+
used_by.append(sub_client.default_output())
120+
elif strict:
121+
# Some other unrecognized use, bail out
122+
return []
123+
elif strict:
124+
# Some other unrecognized use, bail out
125+
return []
126+
127+
return used_by
128+
129+
130+
def wrap_rv_and_conjugate_rv(
131+
fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable]
132+
) -> Variable:
133+
"""Wrap the RV and its conjugate posterior RV in a ConjugateRV node.
134+
135+
Also takes care of handling the random number generators used in the conjugate posterior.
136+
"""
137+
rngs, next_rngs = zip(*collect_default_updates(conjugate_rv, inputs=[rv, *inputs]).items())
138+
for rng in rngs:
139+
if rng not in fgraph.inputs:
140+
fgraph.add_input(rng)
141+
conjugate_op = ConjugateRV(inputs=[rv, *inputs, *rngs], outputs=[rv, conjugate_rv, *next_rngs])
142+
return conjugate_op(rv, *inputs, *rngs)[0]
143+
144+
145+
def create_untransformed_free_rv(
146+
fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable]
147+
) -> Variable:
148+
"""Create a model FreeRV without transform."""
149+
transform = None
150+
value = rv.type(name=name)
151+
fgraph.add_input(value)
152+
free_rv = model_free_rv(rv, value, transform, *dims)
153+
free_rv.name = name
154+
return free_rv
155+
156+
157+
def beta_binomial_conjugacy(fgraph: FunctionGraph, node, eager: bool = False):
158+
if not isinstance(node.op, ModelFreeRV):
159+
return None
160+
161+
[beta_free_rv] = node.outputs
162+
beta_rv, _, *beta_dims = node.inputs
163+
164+
if not isinstance(beta_rv.owner.op, Beta):
165+
return None
166+
167+
a, b = get_dist_params(beta_rv)
168+
if not eager and has_free_rv_ancestor([a, b]):
169+
# Don't apply rewrite if a, b depend on other model variables as that will force a Gibbs sampling scheme
170+
return None
171+
172+
p_arg_idx = 1 # Params to the Binomial are (n, p)
173+
binomial_rvs = rv_used_by(fgraph, beta_free_rv, Binomial, p_arg_idx)
174+
175+
if len(binomial_rvs) != 1:
176+
# Question: Can we apply conjugacy when RV is used by more than one binomial?
177+
return None
178+
179+
[binomial_rv] = binomial_rvs
180+
181+
binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv)
182+
if binomial_model_var is None:
183+
return None
184+
185+
# We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv)
186+
n, _ = get_dist_params(binomial_rv)
187+
188+
# Use value of y in new graph to avoid circularity
189+
y = binomial_model_var.owner.inputs[1]
190+
191+
conjugate_a = sum_bcasted_dims(beta_rv, a + y)
192+
conjugate_b = sum_bcasted_dims(beta_rv, b + (n - y))
193+
194+
conjugate_beta_rv = Beta.dist(conjugate_a, conjugate_b, size=get_size_param(beta_rv))
195+
196+
new_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, conjugate_beta_rv, [a, b, n, y])
197+
new_beta_free_rv = create_untransformed_free_rv(
198+
fgraph, new_beta_rv, beta_free_rv.name, beta_dims
199+
)
200+
return [new_beta_free_rv]
201+
202+
203+
beta_binomial_conjugacy_default, beta_binomial_conjugacy_force = (
204+
register_conjugacy_rewrites_variants(beta_binomial_conjugacy)
205+
)

0 commit comments

Comments
 (0)