Skip to content

Commit 575beee

Browse files
committed
Add Beta-Binomial conjugacy optimization
1 parent b239e01 commit 575beee

File tree

8 files changed

+487
-13
lines changed

8 files changed

+487
-13
lines changed

pymc_experimental/model/marginal/distributions.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pymc.logprob.abstract import MeasurableOp, _logprob
88
from pymc.logprob.basic import conditional_logp, logp
99
from pymc.pytensorf import constant_fold
10-
from pytensor import Variable
1110
from pytensor.compile.builders import OpFromGraph
1211
from pytensor.compile.mode import Mode
1312
from pytensor.graph import Op, vectorize_graph
@@ -17,6 +16,7 @@
1716
from pytensor.tensor import TensorVariable
1817

1918
from pymc_experimental.distributions import DiscreteMarkovChain
19+
from pymc_experimental.utils.ofg import inline_ofg_outputs
2020

2121

2222
class MarginalRV(OpFromGraph, MeasurableOp):
@@ -126,18 +126,6 @@ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> Tens
126126
return logp.transpose(*dims_alignment)
127127

128128

129-
def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
130-
"""Inline the inner graph (outputs) of an OpFromGraph Op.
131-
132-
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
133-
the inner graph.
134-
"""
135-
return clone_replace(
136-
op.inner_outputs,
137-
replace=tuple(zip(op.inner_inputs, inputs)),
138-
)
139-
140-
141129
DUMMY_ZERO = pt.constant(0, name="dummy_zero")
142130

143131

pymc_experimental/sampling/mcmc.py

+38
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,44 @@ def opt_sample(
5252
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
5353
5454
idata = pmx.opt_sample(verbose=True)
55+
56+
# Applied optimization: beta_binomial_conjugacy 1x
57+
# ConjugateRVSampler: [p]
58+
59+
60+
You can control which optimizations are applied using the `include` and `exclude` arguments:
61+
62+
.. code:: python
63+
import pymc as pm
64+
import pymc_experimental as pmx
65+
66+
with pm.Model() as m:
67+
p = pm.Beta("p", 1, 1, shape=(1000,))
68+
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
69+
70+
idata = pmx.opt_sample(exclude="conjugacy", verbose=True)
71+
72+
# No optimizations applied
73+
# NUTS: [p]
74+
75+
.. code:: python
76+
import pymc as pm
77+
import pymc_experimental as pmx
78+
79+
with pm.Model() as m:
80+
a = pm.InverseGamma("a", 1, 1)
81+
b = pm.InverseGamma("b", 1, 1)
82+
p = pm.Beta("p", a, b, shape=(1000,))
83+
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
84+
85+
# By default, the conjugacy of p will not be applied because it depends on other free variables
86+
idata = pmx.opt_sample(include="conjugacy-eager", verbose=True)
87+
88+
# Applied optimization: beta_binomial_conjugacy_eager 1x
89+
# CompoundStep
90+
# >NUTS: [a, b]
91+
# >ConjugateRVSampler: [p]
92+
5593
"""
5694
if kwargs.get("step", None) is not None:
5795
raise ValueError(

pymc_experimental/sampling/optimizations/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# ruff: noqa: F401
22
# Add rewrites to the optimization DBs
3+
import pymc_experimental.sampling.optimizations.conjugacy
34
import pymc_experimental.sampling.optimizations.summary_stats
45

56
from pymc_experimental.sampling.optimizations.optimize import (
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
from collections.abc import Sequence
2+
from functools import partial
3+
4+
from pymc.distributions import Beta, Binomial
5+
from pymc.model.fgraph import ModelFreeRV, ModelValuedVar, model_free_rv
6+
from pymc.pytensorf import collect_default_updates
7+
from pytensor.graph.basic import Variable, ancestors
8+
from pytensor.graph.fg import FunctionGraph, Output
9+
from pytensor.graph.rewriting.basic import node_rewriter
10+
from pytensor.tensor.elemwise import DimShuffle
11+
12+
from pymc_experimental.sampling.optimizations.conjugate_sampler import (
13+
ConjugateRV,
14+
)
15+
from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db
16+
17+
18+
def register_conjugacy_rewrites_variants(rewrite_fn, tracks=(ModelFreeRV,)):
19+
"""Register a rewrite function and its force variant in the posterior optimization DB."""
20+
name = rewrite_fn.__name__
21+
22+
rewrite_fn_default = partial(rewrite_fn, eager=False)
23+
rewrite_fn_default.__name__ = name
24+
rewrite_default = node_rewriter(tracks=tracks)(rewrite_fn_default)
25+
26+
rewrite_fn_eager = partial(rewrite_fn, eager=True)
27+
rewrite_fn_eager.__name__ = f"{name}_eager"
28+
rewrite_eager = node_rewriter(tracks=tracks)(rewrite_fn_eager)
29+
30+
posterior_optimization_db.register(
31+
rewrite_default.__name__,
32+
rewrite_default,
33+
"default",
34+
"conjugacy",
35+
)
36+
37+
posterior_optimization_db.register(
38+
rewrite_eager.__name__,
39+
rewrite_eager,
40+
"non-default",
41+
"conjugacy-eager",
42+
)
43+
44+
return rewrite_default, rewrite_eager
45+
46+
47+
def has_free_rv_ancestor(vars: Variable | Sequence[Variable]) -> bool:
48+
"""Return True if any of the variables have a model variable as an ancestor."""
49+
if not isinstance(vars, Sequence):
50+
vars = (vars,)
51+
52+
# TODO: It should stop at observed RVs, it doesn't matter if they have a free RV above
53+
# Did not implement due to laziness and it being a rare case
54+
return any(
55+
var.owner is not None and isinstance(var.owner.op, ModelFreeRV) for var in ancestors(vars)
56+
)
57+
58+
59+
def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable:
60+
"""Return the Model dummy var that wraps the RV"""
61+
for client, _ in fgraph.clients[rv]:
62+
if isinstance(client.op, ModelValuedVar):
63+
return client.outputs[0]
64+
65+
66+
def get_dist_params(rv: Variable) -> tuple[Variable]:
67+
return rv.owner.op.dist_params(rv.owner)
68+
69+
70+
def rv_used_by(
71+
fgraph: FunctionGraph,
72+
rv: Variable,
73+
used_by_type: type,
74+
used_as_arg_idx: int | Sequence[int],
75+
strict: bool = True,
76+
) -> list[Variable]:
77+
"""Return the RVs that use `rv` as an argument in an operation of type `used_by_type`.
78+
79+
RV may be used directly or broadcasted before being used.
80+
81+
Parameters
82+
----------
83+
fgraph : FunctionGraph
84+
The function graph containing the RVs
85+
rv : Variable
86+
The RV to check for uses.
87+
used_by_type : type
88+
The type of operation that may use the RV.
89+
used_as_arg_idx : int | Sequence[int]
90+
The index of the RV in the operation's inputs.
91+
strict : bool, default=True
92+
If True, return no results when the RV is used in an unrecognized way.
93+
94+
"""
95+
if isinstance(used_as_arg_idx, int):
96+
used_as_arg_idx = (used_as_arg_idx,)
97+
98+
clients = fgraph.clients
99+
used_by: list[Variable] = []
100+
for client, inp_idx in clients[rv]:
101+
if isinstance(client.op, Output):
102+
continue
103+
104+
if isinstance(client.op, used_by_type) and inp_idx in used_as_arg_idx:
105+
# RV is directly used by the RV type
106+
used_by.append(client.default_output())
107+
108+
elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims:
109+
for sub_client, sub_inp_idx in clients[client.outputs[0]]:
110+
if isinstance(sub_client.op, used_by_type) and sub_inp_idx in used_as_arg_idx:
111+
# RV is broadcasted and then used by the RV type
112+
used_by.append(sub_client.default_output())
113+
elif strict:
114+
# Some other unrecognized use, bail out
115+
return []
116+
elif strict:
117+
# Some other unrecognized use, bail out
118+
return []
119+
120+
return used_by
121+
122+
123+
def wrap_rv_and_conjugate_rv(
124+
fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable]
125+
) -> Variable:
126+
"""Wrap the RV and its conjugate posterior RV in a ConjugateRV node.
127+
128+
Also takes care of handling the random number generators used in the conjugate posterior.
129+
"""
130+
rngs, next_rngs = zip(*collect_default_updates(conjugate_rv, inputs=[rv, *inputs]).items())
131+
for rng in rngs:
132+
if rng not in fgraph.inputs:
133+
fgraph.add_input(rng)
134+
conjugate_op = ConjugateRV(inputs=[rv, *inputs, *rngs], outputs=[rv, conjugate_rv, *next_rngs])
135+
return conjugate_op(rv, *inputs, *rngs)[0]
136+
137+
138+
def create_untransformed_free_rv(
139+
fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable]
140+
) -> Variable:
141+
"""Create a model FreeRV without transform."""
142+
transform = None
143+
value = rv.type(name=name)
144+
fgraph.add_input(value)
145+
free_rv = model_free_rv(rv, value, transform, *dims)
146+
free_rv.name = name
147+
return free_rv
148+
149+
150+
def beta_binomial_conjugacy(fgraph: FunctionGraph, node, eager: bool = False):
151+
if not isinstance(node.op, ModelFreeRV):
152+
return None
153+
154+
[beta_free_rv] = node.outputs
155+
beta_rv, beta_value, *beta_dims = node.inputs
156+
157+
if not isinstance(beta_rv.owner.op, Beta):
158+
return None
159+
160+
a, b = get_dist_params(beta_rv)
161+
if not eager and has_free_rv_ancestor([a, b]):
162+
# Don't apply rewrite if a, b depend on other model variables as that will force a Gibbs sampling scheme
163+
return None
164+
165+
p_arg_idx = 3 # inputs to Binomial are (rng, size, n, p)
166+
binomial_rvs = rv_used_by(fgraph, beta_free_rv, Binomial, p_arg_idx)
167+
168+
if len(binomial_rvs) != 1:
169+
# Question: Can we apply conjugacy when RV is used by more than one binomial?
170+
return None
171+
172+
[binomial_rv] = binomial_rvs
173+
174+
binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv)
175+
if binomial_model_var is None:
176+
return None
177+
178+
# We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv)
179+
n, _ = get_dist_params(binomial_rv)
180+
181+
# Use value of y in new graph to avoid circularity
182+
y = binomial_model_var.owner.inputs[1]
183+
184+
conjugate_a = a + y
185+
conjugate_b = b + (n - y)
186+
extra_dims = range(binomial_rv.type.ndim - beta_rv.type.ndim)
187+
if extra_dims:
188+
conjugate_a = conjugate_a.sum(extra_dims)
189+
conjugate_b = conjugate_b.sum(extra_dims)
190+
conjugate_beta_rv = Beta.dist(conjugate_a, conjugate_b)
191+
192+
new_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, conjugate_beta_rv, [a, b, n, y])
193+
new_beta_free_rv = create_untransformed_free_rv(
194+
fgraph, new_beta_rv, beta_free_rv.name, beta_dims
195+
)
196+
return [new_beta_free_rv]
197+
198+
199+
beta_binomial_conjugacy_default, beta_binomial_conjugacy_force = (
200+
register_conjugacy_rewrites_variants(beta_binomial_conjugacy)
201+
)

0 commit comments

Comments
 (0)