Skip to content

Commit 2525ea8

Browse files
committed
Add Beta-Binomial conjugacy optimization
1 parent 493855f commit 2525ea8

File tree

9 files changed

+410
-18
lines changed

9 files changed

+410
-18
lines changed

pymc_experimental/model/marginal/distributions.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pymc.logprob.abstract import MeasurableOp, _logprob
88
from pymc.logprob.basic import conditional_logp, logp
99
from pymc.pytensorf import constant_fold
10-
from pytensor import Variable
1110
from pytensor.compile.builders import OpFromGraph
1211
from pytensor.compile.mode import Mode
1312
from pytensor.graph import Op, vectorize_graph
@@ -17,6 +16,7 @@
1716
from pytensor.tensor import TensorVariable
1817

1918
from pymc_experimental.distributions import DiscreteMarkovChain
19+
from pymc_experimental.utils.ofg import inline_ofg_outputs
2020

2121

2222
class MarginalRV(OpFromGraph, MeasurableOp):
@@ -126,18 +126,6 @@ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> Tens
126126
return logp.transpose(*dims_alignment)
127127

128128

129-
def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
130-
"""Inline the inner graph (outputs) of an OpFromGraph Op.
131-
132-
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
133-
the inner graph.
134-
"""
135-
return clone_replace(
136-
op.inner_outputs,
137-
replace=tuple(zip(op.inner_inputs, inputs)),
138-
)
139-
140-
141129
DUMMY_ZERO = pt.constant(0, name="dummy_zero")
142130

143131

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
# ruff: noqa: F401
22
# Add rewrites to the optimization DBs
3+
import pymc_experimental.sampling.optimizations.conjugacy
34
import pymc_experimental.sampling.optimizations.summary_stats

pymc_experimental/sampling/mcmc.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ def opt_sample(
2020
):
2121
"""Sample from a model after applying optimizations.
2222
23+
.. warning:: There is no guarantee that the optimizations will improve the sampling performance. For instance, conjugacy optimizations can lead to less efficient sampling for the remaining variables (if any), due to imposing a Gibbs sampling scheme.
24+
25+
2326
Parameters
2427
----------
2528
model : Model (optional)
@@ -43,17 +46,22 @@ def opt_sample(
4346
import pymc_experimental as pmx
4447
4548
with pm.Model() as m:
46-
p = pm.Beta("p", 1, 1)
47-
y = pm.Binomial("y", n=10, p=p, observed=5)
49+
p = pm.Beta("p", 1, 1, shape=(1000,))
50+
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)
4851
4952
idata = pmx.opt_sample(verbose=True)
53+
54+
# Applied optimization: beta_binomial_conjugacy 1x
55+
# ConjugateRVSampler: [p]
5056
"""
5157

5258
model = modelcontext(model)
5359
fgraph, _ = fgraph_from_model(model)
5460

5561
if rewriter is None:
56-
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["summary_stats"]))
62+
rewriter = posterior_optimization_db.query(
63+
RewriteDatabaseQuery(include=["summary_stats", "conjugacy"])
64+
)
5765
_, _, rewrite_counters, *_ = rewriter.rewrite(fgraph)
5866

5967
if verbose:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
from collections.abc import Sequence
2+
3+
from pymc.distributions import Beta, Binomial
4+
from pymc.model.fgraph import ModelFreeRV, ModelValuedVar, model_free_rv
5+
from pymc.pytensorf import collect_default_updates
6+
from pytensor.graph.basic import Variable
7+
from pytensor.graph.fg import FunctionGraph, Output
8+
from pytensor.graph.rewriting.basic import node_rewriter
9+
from pytensor.tensor.elemwise import DimShuffle
10+
11+
from pymc_experimental.sampling.mcmc import posterior_optimization_db
12+
from pymc_experimental.sampling.optimizations.conjugate_sampler import (
13+
ConjugateRV,
14+
)
15+
16+
17+
def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable | None:
    """Return the Model dummy var that wraps the RV.

    Looks through the clients of `rv` in `fgraph` and returns the first output of
    the first client whose Op is a `ModelValuedVar`.

    Returns
    -------
    Variable or None
        The wrapping model variable, or None when no client of `rv` is a
        `ModelValuedVar`.
    """
    for client, _ in fgraph.clients[rv]:
        if isinstance(client.op, ModelValuedVar):
            return client.outputs[0]
    # Make the "not found" outcome explicit; callers check for None.
    return None
22+
23+
24+
def get_dist_params(rv: Variable) -> tuple[Variable]:
    """Return the distribution parameters of `rv`, as reported by its Op's `dist_params`."""
    rv_node = rv.owner
    return rv_node.op.dist_params(rv_node)
26+
27+
28+
def rv_used_by(
    fgraph: FunctionGraph,
    rv: Variable,
    used_by_type: type,
    used_as_arg_idx: int | Sequence[int],
    strict: bool = True,
) -> list[Variable]:
    """Collect the variables produced by `used_by_type` operations that take `rv` as input.

    A use counts when `rv` feeds the operation directly, or when `rv` first goes
    through a left-expand-dims `DimShuffle` whose output feeds the operation.

    Parameters
    ----------
    fgraph : FunctionGraph
        The function graph containing the RVs
    rv : Variable
        The RV to check for uses.
    used_by_type : type
        The type of operation that may use the RV.
    used_as_arg_idx : int | Sequence[int]
        The input position(s) at which the RV must appear in the operation.
    strict : bool, default=True
        If True, return an empty list as soon as the RV is used in an unrecognized way.
        Clients that are graph `Output` nodes are always ignored.

    Returns
    -------
    list of Variable
        The default output of every matching operation.
    """
    arg_positions = (used_as_arg_idx,) if isinstance(used_as_arg_idx, int) else used_as_arg_idx
    all_clients = fgraph.clients

    def is_expected_use(node, position) -> bool:
        return isinstance(node.op, used_by_type) and position in arg_positions

    users: list[Variable] = []
    for node, position in all_clients[rv]:
        if isinstance(node.op, Output):
            # Being a graph output is not a "use" we care about
            continue

        if is_expected_use(node, position):
            # Direct use by the requested operation type
            users.append(node.default_output())
            continue

        if not (isinstance(node.op, DimShuffle) and node.op.is_left_expand_dims):
            if strict:
                # Unrecognized use of the RV: give up
                return []
            continue

        # RV is expanded on the left first; inspect what consumes the expanded variable
        for inner_node, inner_position in all_clients[node.outputs[0]]:
            if is_expected_use(inner_node, inner_position):
                users.append(inner_node.default_output())
            elif strict:
                # Unrecognized use of the expanded RV: give up
                return []

    return users
79+
80+
81+
def wrap_rv_and_conjugate_rv(
    fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable]
) -> Variable:
    """Build a ConjugateRV node holding both the RV and its conjugate posterior RV.

    The random number generators found by `collect_default_updates` for the conjugate
    posterior become extra inputs of the node (and are added to `fgraph` inputs when
    missing), and their updated counterparts become extra outputs.

    Returns the first output of the new node, which corresponds to the original `rv`.
    """
    rng_updates = collect_default_updates(conjugate_rv, inputs=[rv, *inputs])
    rngs, next_rngs = zip(*rng_updates.items())

    # The graph must know about every RNG the wrapped node depends on
    for rng in rngs:
        if rng not in fgraph.inputs:
            fgraph.add_input(rng)

    op_args = (rv, *inputs, *rngs)
    conjugate_op = ConjugateRV(
        inputs=list(op_args),
        outputs=[rv, conjugate_rv, *next_rngs],
    )
    wrapped_outputs = conjugate_op(*op_args)
    return wrapped_outputs[0]
94+
95+
96+
def create_untransformed_free_rv(
    fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable]
) -> Variable:
    """Create a model FreeRV for `rv` with no transform.

    A fresh value variable of the same type as `rv` is created, named `name`, and
    registered as an input of `fgraph`. The resulting free RV is also named `name`.
    """
    value_var = rv.type(name=name)
    fgraph.add_input(value_var)

    # Third positional argument is the transform: None means untransformed
    untransformed_rv = model_free_rv(rv, value_var, None, *dims)
    untransformed_rv.name = name
    return untransformed_rv
106+
107+
108+
@node_rewriter(tracks=[ModelFreeRV])
def beta_binomial_conjugacy(fgraph: FunctionGraph, node):
    """Replace a Beta free RV used as `p` of a single valued Binomial by a ConjugateRV.

    The rewrite applies when the tracked free RV is Beta distributed and is used
    (directly or after a left expand_dims) as the `p` argument of exactly one
    Binomial RV that is itself wrapped in a model valued variable. In that case the
    posterior of the Beta given the Binomial value `y` is

        Beta(a + sum(y), b + sum(n - y))

    where the sums run over the leading batch dimensions the Binomial has in excess
    of the Beta.

    Returns
    -------
    list of Variable or None
        The replacement free RV (untransformed, wrapping a ConjugateRV), or None
        when the pattern does not match.
    """
    [beta_free_rv] = node.outputs
    beta_rv, _beta_value, *beta_dims = node.inputs

    if not isinstance(beta_rv.owner.op, Beta):
        return None

    p_arg_idx = 3  # inputs to Binomial are (rng, size, n, p)
    binomial_rvs = rv_used_by(fgraph, beta_free_rv, Binomial, p_arg_idx)

    if len(binomial_rvs) != 1:
        # Question: Can we apply conjugacy when RV is used by more than one binomial?
        return None

    [binomial_rv] = binomial_rvs

    binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv)
    if binomial_model_var is None:
        return None

    # We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv)
    a, b = get_dist_params(beta_rv)
    n, _ = get_dist_params(binomial_rv)

    # Use value of y in new graph to avoid circularity
    y = binomial_model_var.owner.inputs[1]

    # Sufficient statistics: number of successes and number of failures
    successes = y
    failures = n - y
    extra_dims = tuple(range(binomial_rv.type.ndim - beta_rv.type.ndim))
    if extra_dims:
        # Several observations share the same p: pool them before adding the prior.
        # Summing `a + y` instead would count the prior once per observation.
        successes = successes.sum(extra_dims)
        failures = failures.sum(extra_dims)
    conjugate_a = a + successes
    conjugate_b = b + failures
    conjugate_beta_rv = Beta.dist(conjugate_a, conjugate_b)

    new_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, conjugate_beta_rv, [a, b, n, y])
    new_beta_free_rv = create_untransformed_free_rv(
        fgraph, new_beta_rv, beta_free_rv.name, beta_dims
    )
    return [new_beta_free_rv]
149+
150+
151+
# Register the rewrite in the posterior optimization database under its own name,
# tagged with "conjugacy" so that database queries including that tag select it.
posterior_optimization_db.register(
    beta_binomial_conjugacy.__name__, beta_binomial_conjugacy, "conjugacy"
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import numpy as np
2+
3+
from pymc import STEP_METHODS
4+
from pymc.distributions.distribution import _support_point
5+
from pymc.initial_point import PointType
6+
from pymc.logprob.abstract import MeasurableOp, _logprob
7+
from pymc.model.core import modelcontext
8+
from pymc.pytensorf import compile_pymc
9+
from pymc.step_methods.compound import BlockedStep, Competence, StepMethodState
10+
from pymc.util import get_value_vars_from_user_vars
11+
from pytensor import shared
12+
from pytensor.compile.builders import OpFromGraph
13+
from pytensor.link.jax.linker import JAXLinker
14+
from pytensor.tensor.random.type import RandomGeneratorType
15+
16+
from pymc_experimental.utils.ofg import inline_ofg_outputs
17+
18+
19+
class ConjugateRV(OpFromGraph, MeasurableOp):
    """Wrapper Op that outputs the original RV and its conjugate posterior expression.

    The first output is the original RV and the second output is the conjugate
    posterior RV.

    For partial step samplers to work, the logp and initial point correspond to the
    original RV, while the variable itself is sampled by default by the
    `ConjugateRVSampler`, which directly evaluates the conjugate posterior expression
    (i.e., takes forward random draws).
    """
26+
27+
28+
@_logprob.register(ConjugateRV)
def conjugate_rv_logp(op, values, rv, *params, **kwargs):
    """Return the logp of a ConjugateRV, which is the logp of the original RV it wraps."""
    original_node = rv.owner
    return _logprob(original_node.op, values, *original_node.inputs)
32+
33+
34+
@_support_point.register(ConjugateRV)
def conjugate_rv_support_point(op, conjugate_rv, rv, *params):
    """Return the support point of a ConjugateRV, which is that of the original RV it wraps."""
    original_node = rv.owner
    return _support_point(original_node.op, rv, *original_node.inputs)
38+
39+
40+
class ConjugateRVSampler(BlockedStep):
    """Step method that updates a variable by drawing from its conjugate posterior.

    The sampler is assigned to a single, untransformed variable whose RV is the
    output of a `ConjugateRV` Op. At every step it evaluates a compiled function of
    the model value variables that returns a draw of the conjugate posterior RV.
    """

    name = "conjugate_rv_sampler"
    _state_class = StepMethodState

    def __init__(self, vars, model=None, rng=None, compile_kwargs: dict | None = None, **kwargs):
        """Compile the conjugate posterior draw function for the assigned variable.

        Parameters
        ----------
        vars : sequence
            Exactly one model variable, which must be an untransformed ConjugateRV.
        model : Model, optional
            Model the variable belongs to; resolved with `modelcontext` when None.
        rng : optional
            Passed as `random_seed` to `compile_pymc`.
        compile_kwargs : dict, optional
            Extra keyword arguments forwarded to `compile_pymc`.

        Raises
        ------
        ValueError
            If more than one variable is given, the variable is transformed, the
            variable is not a ConjugateRV, or the function was compiled with the
            JAX linker.
        """
        if len(vars) != 1:
            raise ValueError("ConjugateRVSampler can only be assigned to one variable at a time")

        model = modelcontext(model)
        [value] = get_value_vars_from_user_vars(vars, model=model)
        rv = model.values_to_rvs[value]
        self.vars = (value,)
        self.rv_name = value.name

        if model.rvs_to_transforms[rv] is not None:
            raise ValueError("Variable assigned to ConjugateRVSampler cannot be transformed")

        rv_and_posterior_rv_node = rv.owner
        op = rv_and_posterior_rv_node.op
        if not isinstance(op, ConjugateRV):
            raise ValueError("Variable must be a ConjugateRV")

        # Replace RVs in inputs of rv_posterior_rv_node by the corresponding value variables
        value_inputs = model.replace_rvs_by_values(
            [rv_and_posterior_rv_node.outputs[1]],
        )[0].owner.inputs
        # Inline the ConjugateRV graph to only compile `posterior_rv`
        _, posterior_rv, *_ = inline_ofg_outputs(op, value_inputs)

        if compile_kwargs is None:
            compile_kwargs = {}
        self.posterior_fn = compile_pymc(
            model.value_vars,
            posterior_rv,
            random_seed=rng,
            on_unused_input="ignore",
            **compile_kwargs,
        )
        self.posterior_fn.trust_input = True
        if isinstance(self.posterior_fn.maker.linker, JAXLinker):
            # Reseeding RVs in JAX backend requires a different logic, because the SharedVariables
            # used internally are not the ones that `function.get_shared()` returns.
            raise ValueError("ConjugateRVSampler is not compatible with JAX backend")

    def set_rng(self, rng: np.random.Generator):
        """Replace the shared RNGs of the compiled posterior function.

        Each shared RNG variable of the function is swapped for a new shared variable
        holding a generator spawned from `rng`.
        """
        # Copy the function and replace any shared RNGs
        # This is needed so that it can work correctly with multiple traces
        # This will be costly if set_rng is called too often!
        shared_rngs = [
            var
            for var in self.posterior_fn.get_shared()
            if isinstance(var.type, RandomGeneratorType)
        ]
        n_shared_rngs = len(shared_rngs)
        # `child_rng` deliberately does not reuse the name `rng`, to avoid shadowing the argument
        swap = {
            old_shared_rng: shared(child_rng, borrow=True)
            for old_shared_rng, child_rng in zip(
                shared_rngs, rng.spawn(n_shared_rngs), strict=True
            )
        }
        self.posterior_fn = self.posterior_fn.copy(swap=swap)

    def step(self, point: PointType) -> tuple[PointType, list]:
        """Return a copy of `point` where the assigned variable is a fresh posterior draw.

        The second element of the returned tuple (sampler stats) is always an empty list.
        """
        new_point = point.copy()
        new_point[self.rv_name] = self.posterior_fn(**point)
        return new_point, []

    @staticmethod
    def competence(var, has_grad):
        """ConjugateRVSampler is ideal for ConjugateRV variables and incompatible otherwise."""
        if isinstance(var.owner.op, ConjugateRV):
            return Competence.IDEAL

        return Competence.INCOMPATIBLE
112+
113+
114+
# Register the ConjugateRVSampler in PyMC's list of step methods.
# NOTE(review): presumably this is what lets PyMC pick it automatically (through
# `competence`) for ConjugateRV variables during step assignment; confirm against PyMC.
STEP_METHODS.append(ConjugateRVSampler)

pymc_experimental/utils/ofg.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from collections.abc import Sequence
2+
3+
from pytensor.compile.builders import OpFromGraph
4+
from pytensor.graph.basic import Variable
5+
from pytensor.graph.replace import clone_replace
6+
7+
8+
def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
    """Return the inner outputs of an OpFromGraph, expressed in terms of `inputs`.

    An `OpFromGraph` hides a graph behind a single Op. This function does the
    opposite: it clones the inner outputs while substituting each inner input by the
    corresponding entry of `inputs` (paired by position).
    """
    replacements = tuple(zip(op.inner_inputs, inputs))
    return clone_replace(op.inner_outputs, replace=replacements)

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ addopts = [
88
]
99

1010
filterwarnings =[
11-
"error",
11+
# "error",
1212
# Raised by arviz when the model_builder class adds non-standard group names to InferenceData
1313
"ignore::UserWarning:arviz.data.inference_data",
1414

0 commit comments

Comments
 (0)