Skip to content

Commit dbe62e8

Browse files
committed
Add Beta-Binomial conjugacy optimization
1 parent e051965 commit dbe62e8

File tree

9 files changed

+368
-16
lines changed

9 files changed

+368
-16
lines changed

pymc_experimental/model/marginal/distributions.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pymc.logprob.abstract import MeasurableOp, _logprob
88
from pymc.logprob.basic import conditional_logp, logp
99
from pymc.pytensorf import constant_fold
10-
from pytensor import Variable
1110
from pytensor.compile.builders import OpFromGraph
1211
from pytensor.compile.mode import Mode
1312
from pytensor.graph import Op, vectorize_graph
@@ -17,6 +16,7 @@
1716
from pytensor.tensor import TensorVariable
1817

1918
from pymc_experimental.distributions import DiscreteMarkovChain
19+
from pymc_experimental.utils.ofg import inline_ofg_outputs
2020

2121

2222
class MarginalRV(OpFromGraph, MeasurableOp):
@@ -126,18 +126,6 @@ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> Tens
126126
return logp.transpose(*dims_alignment)
127127

128128

129-
def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
130-
"""Inline the inner graph (outputs) of an OpFromGraph Op.
131-
132-
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
133-
the inner graph.
134-
"""
135-
return clone_replace(
136-
op.inner_outputs,
137-
replace=tuple(zip(op.inner_inputs, inputs)),
138-
)
139-
140-
141129
DUMMY_ZERO = pt.constant(0, name="dummy_zero")
142130

143131

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
# Add rewrites to the optimization DBs
2+
import pymc_experimental.sampling.optimizations.conjugacy
23
import pymc_experimental.sampling.optimizations.summary_stats

pymc_experimental/sampling/mcmc.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ def opt_sample(
2020
):
2121
"""Sample from a model after applying optimizations.
2222
23+
.. warning:: There is no guarantee that the optimizations will improve the sampling performance. For instance, conjugacy optimizations can lead to less efficient sampling for the remaining variables (if any), due to imposing a Gibbs sampling scheme.
24+
25+
2326
Parameters
2427
----------
2528
model : Model (optional)
@@ -47,13 +50,16 @@ def opt_sample(
4750
y = pm.Binomial("y", n=10, p=p, observed=5)
4851
4952
idata = pmx.opt_sample(verbose=True)
53+
54+
# Applied optimization: beta_binomial_conjugacy 1x
55+
# ConjugateRVSampler: [p]
5056
"""
5157

5258
model = modelcontext(model)
5359
fgraph, _ = fgraph_from_model(model)
5460

5561
if rewriter is None:
56-
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["summary_stats"]))
62+
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["summary_stats", "conjugacy"]))
5763
_, _, rewrite_counters, *_ = rewriter.rewrite(fgraph)
5864

5965
if verbose:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
from typing import Sequence
2+
3+
from pymc import STEP_METHODS
4+
from pytensor.tensor.random.type import RandomGeneratorType
5+
6+
from pytensor.compile.builders import OpFromGraph
7+
8+
from pymc_experimental.sampling.mcmc import posterior_optimization_db
9+
from pymc_experimental.sampling.optimizations.conjugate_sampler import ConjugateRV, ConjugateRVSampler
10+
11+
# Add the sampler to PyMC's list of known step methods at import time
# (presumably so step assignment can select it through `competence`; verify against pymc).
STEP_METHODS.append(ConjugateRVSampler)
12+
13+
from pytensor.graph.fg import Output
14+
from pytensor.tensor.elemwise import DimShuffle
15+
from pymc.model.fgraph import model_free_rv, ModelValuedVar
16+
17+
18+
from pytensor.graph.basic import Variable
19+
from pytensor.graph.fg import FunctionGraph
20+
from pytensor.graph.rewriting.basic import node_rewriter
21+
from pymc.model.fgraph import ModelFreeRV
22+
from pymc.distributions import Beta, Binomial
23+
from pymc.pytensorf import collect_default_updates
24+
25+
26+
def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable | None:
    """Return the Model dummy var that wraps the RV, if any.

    Parameters
    ----------
    fgraph : FunctionGraph
        The function graph whose ``clients`` mapping is searched.
    rv : Variable
        The RV whose wrapping model variable is requested.

    Returns
    -------
    Variable | None
        The first output of the first client of `rv` whose Op is a
        `ModelValuedVar`, or None when no such client exists.
    """
    for client, _ in fgraph.clients[rv]:
        if isinstance(client.op, ModelValuedVar):
            return client.outputs[0]
    # Callers test for None to detect an RV that is not wrapped by a model variable.
    return None
31+
32+
33+
def get_dist_params(rv: Variable) -> tuple[Variable]:
    """Return the distribution parameters of `rv`, as given by its Op's `dist_params`."""
    node = rv.owner
    return node.op.dist_params(node)
35+
36+
37+
def rv_used_by(fgraph: FunctionGraph, rv: Variable, used_by_type: type, used_as_arg_idx: int | Sequence[int], strict: bool = True) -> list[Variable]:
    """Collect the outputs of `used_by_type` operations that take `rv` as an argument.

    The RV counts as used either directly, or after a left-expand-dims
    `DimShuffle` (i.e., broadcasted before being used).

    Parameters
    ----------
    fgraph : FunctionGraph
        The function graph containing the RVs
    rv : Variable
        The RV to check for uses.
    used_by_type : type
        The type of operation that may use the RV.
    used_as_arg_idx : int | Sequence[int]
        The input position(s) at which the operation must receive the RV.
    strict : bool, default=True
        If True, return no results when the RV is used in an unrecognized way.

    Returns
    -------
    list[Variable]
        The default output of every matching operation; empty when there is
        none, or when `strict` is True and an unrecognized use was found.
    """
    arg_positions = (used_as_arg_idx,) if isinstance(used_as_arg_idx, int) else used_as_arg_idx

    def is_target_use(node, position) -> bool:
        # The node is of the requested type and receives the variable at an accepted position
        return isinstance(node.op, used_by_type) and position in arg_positions

    clients = fgraph.clients
    users: list[Variable] = []
    for node, position in clients[rv]:
        if isinstance(node.op, Output):
            # Being a graph output does not count as a use
            continue

        if is_target_use(node, position):
            # RV is directly used by the RV type
            users.append(node.default_output())
            continue

        if isinstance(node.op, DimShuffle) and node.op.is_left_expand_dims:
            for inner_node, inner_position in clients[node.outputs[0]]:
                if is_target_use(inner_node, inner_position):
                    # RV is broadcasted and then used by the RV type
                    users.append(inner_node.default_output())
                elif strict:
                    # Some other unrecognized use, bail out
                    return []
            continue

        if strict:
            # Some other unrecognized use, bail out
            return []

    return users
82+
83+
84+
def wrap_rv_and_conjugate_rv(fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable]) -> Variable:
    """Bundle an RV and its conjugate posterior RV into a single ConjugateRV node.

    The random number generators found by `collect_default_updates` for the
    conjugate posterior are added as graph inputs (when not already present)
    and threaded through the ConjugateRV as extra inputs/outputs.

    Returns the first output of the new node, which corresponds to `rv`.
    """
    inner_inputs = [rv, *inputs]
    rng_updates = collect_default_updates(conjugate_rv, inputs=inner_inputs)
    rngs, next_rngs = zip(*rng_updates.items())

    for rng in rngs:
        if rng not in fgraph.inputs:
            fgraph.add_input(rng)

    conjugate_op = ConjugateRV(
        inputs=[*inner_inputs, *rngs],
        outputs=[rv, conjugate_rv, *next_rngs],
    )
    node_outputs = conjugate_op(*inner_inputs, *rngs)
    return node_outputs[0]
95+
96+
97+
def create_untransformed_free_rv(fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable]) -> Variable:
    """Build a model FreeRV for `rv` with no transform.

    A fresh value variable of the same type as `rv` is created, registered as
    an input of `fgraph`, and paired with `rv` through `model_free_rv`.
    Both the value variable and the returned FreeRV are named `name`.
    """
    value_var = rv.type(name=name)
    fgraph.add_input(value_var)
    # Third positional argument is the transform: None means untransformed
    free_rv = model_free_rv(rv, value_var, None, *dims)
    free_rv.name = name
    return free_rv
105+
106+
107+
@node_rewriter(tracks=[ModelFreeRV])
def beta_binomial_conjugacy(fgraph: FunctionGraph, node):
    """Rewrite a Beta free RV used as `p` of a single Binomial into a ConjugateRV.

    For ``p ~ Beta(a, b)`` and ``y ~ Binomial(n, p)`` the replacement carries the
    conjugate posterior ``Beta(a + y, b + (n - y))``, with the added terms summed
    over the leading dimensions the Binomial has in excess of the Beta.

    Returns None (no rewrite) when the tracked RV is not a Beta, when it is not
    used as `p` by exactly one Binomial, or when that Binomial is not wrapped
    by a model variable.
    """
    [beta_free_rv] = node.outputs
    beta_rv, _beta_value, *beta_dims = node.inputs

    if not isinstance(beta_rv.owner.op, Beta):
        return None

    # inputs to Binomial are (rng, size, n, p)
    p_position = 3
    candidates = rv_used_by(fgraph, beta_free_rv, Binomial, p_position)
    if len(candidates) != 1:
        # Question: Can we apply conjugacy when RV is used by more than one binomial?
        return None
    binomial_rv = candidates[0]

    binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv)
    if binomial_model_var is None:
        return None

    # We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv)
    a, b = get_dist_params(beta_rv)
    n, _ = get_dist_params(binomial_rv)

    # Use value of y in new graph to avoid circularity
    y = binomial_model_var.owner.inputs[1]

    posterior_a = a + y
    posterior_b = b + (n - y)
    # NOTE(review): only extra leading dims are reduced; broadcasting of the Beta
    # along same-rank dims does not appear to be handled here - confirm.
    extra_dims = range(binomial_rv.type.ndim - beta_rv.type.ndim)
    if extra_dims:
        posterior_a = posterior_a.sum(extra_dims)
        posterior_b = posterior_b.sum(extra_dims)
    posterior_beta_rv = Beta.dist(posterior_a, posterior_b)

    wrapped_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, posterior_beta_rv, [a, b, n, y])
    replacement = create_untransformed_free_rv(fgraph, wrapped_beta_rv, beta_free_rv.name, beta_dims)
    return [replacement]
150+
151+
152+
# Register the rewrite under its own name, tagged "conjugacy" so database queries can include it.
posterior_optimization_db.register(beta_binomial_conjugacy.__name__, beta_binomial_conjugacy, "conjugacy")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import numpy as np
2+
3+
from pymc_experimental.utils.ofg import inline_ofg_outputs
4+
from pytensor.compile.builders import OpFromGraph
5+
from pymc.logprob.abstract import MeasurableOp, _logprob
6+
from pymc.distributions.distribution import _support_point
7+
from pymc.step_methods.compound import BlockedStep, StepMethodState, Competence
8+
from pymc.model.core import modelcontext
9+
from pymc.util import get_value_vars_from_user_vars
10+
from pymc.pytensorf import compile_pymc
11+
from pytensor import shared
12+
from pytensor.tensor.random.type import RandomGeneratorType
13+
from pytensor.link.jax.linker import JAXLinker
14+
from pymc.initial_point import PointType
15+
16+
class ConjugateRV(OpFromGraph, MeasurableOp):
    """Op that wraps an RV together with its conjugate posterior expression.

    It outputs both the original RV and the conjugate posterior expression.
    So that partial step samplers keep working, the logp and the initial point
    are those of the original RV. The variable itself is, by default, sampled
    by `ConjugateRVSampler`, which evaluates the conjugate posterior expression
    directly (i.e., takes forward random draws).
    """
23+
24+
25+
@_logprob.register(ConjugateRV)
def conjugate_rv_logp(op, values, rv, *params, **kwargs):
    """Return the logp of the original RV wrapped by the ConjugateRV.

    The dispatch is forwarded to the Op and inputs of `rv`; `params` and
    `kwargs` are not used.
    """
    original_node = rv.owner
    return _logprob(original_node.op, values, *original_node.inputs)
29+
30+
31+
@_support_point.register(ConjugateRV)
def conjugate_rv_support_point(op, conjugate_rv, rv, *params):
    """Return the support point of the original RV wrapped by the ConjugateRV.

    The dispatch is forwarded to the Op and inputs of `rv`; `params` are not used.
    """
    original_node = rv.owner
    return _support_point(original_node.op, rv, *original_node.inputs)
35+
36+
37+
class ConjugateRVSampler(BlockedStep):
    """Step method that updates one ConjugateRV by drawing from its conjugate posterior.

    The posterior expression stored in the ConjugateRV node is compiled once,
    as a function of the model value variables, and evaluated at every step.
    """

    name = "conjugate_rv_sampler"
    _state_class = StepMethodState

    def __init__(self, vars, model=None, rng=None, compile_kwargs: dict | None = None, **kwargs):
        """Compile the posterior draw function for the single assigned variable.

        Parameters
        ----------
        vars : sequence
            Exactly one user variable, which must be an untransformed ConjugateRV.
        model : Model, optional
            Model the variable belongs to; resolved with `modelcontext`.
        rng : optional
            Passed as `random_seed` to `compile_pymc`.
        compile_kwargs : dict, optional
            Extra keyword arguments forwarded to `compile_pymc`.
        **kwargs
            Accepted but not used.

        Raises
        ------
        ValueError
            If more than one variable is given, the variable is transformed,
            the variable is not a ConjugateRV, or the compiled function uses
            the JAX linker.
        """
        if len(vars) != 1:
            raise ValueError("ConjugateRVSampler can only be assigned to one variable at a time")

        model = modelcontext(model)
        [value] = get_value_vars_from_user_vars(vars, model=model)
        rv = model.values_to_rvs[value]
        self.vars = (value,)
        self.rv_name = value.name

        if model.rvs_to_transforms[rv] is not None:
            raise ValueError("Variable assigned to ConjugateRVSampler cannot be transformed")

        rv_and_posterior_rv_node = rv.owner
        op = rv_and_posterior_rv_node.op
        if not isinstance(op, ConjugateRV):
            raise ValueError("Variable must be a ConjugateRV")

        # Replace RVs in inputs of rv_posterior_rv_node by the corresponding value variables
        value_inputs = model.replace_rvs_by_values(
            [rv_and_posterior_rv_node.outputs[1]],
        )[0].owner.inputs
        # Inline the ConjugateRV graph to only compile `posterior_rv`
        _, posterior_rv, *_ = inline_ofg_outputs(op, value_inputs)

        if compile_kwargs is None:
            compile_kwargs = {}
        self.posterior_fn = compile_pymc(
            model.value_vars,
            posterior_rv,
            random_seed=rng,
            on_unused_input="ignore",
            **compile_kwargs,
        )
        self.posterior_fn.trust_input = True
        if isinstance(self.posterior_fn.maker.linker, JAXLinker):
            # Reseeding RVs in JAX backend requires a different logic, because the SharedVariables
            # used internally are not the ones that `function.get_shared()` returns.
            raise ValueError("ConjugateRVSampler is not compatible with JAX backend")

    def set_rng(self, rng: np.random.Generator):
        """Replace the shared RNGs of the compiled function by children spawned from `rng`."""
        # Copy the function and replace any shared RNGs
        # This is needed so that it can work correctly with multiple traces
        # This will be costly if set_rng is called too often!
        shared_rngs = [
            var for var in self.posterior_fn.get_shared() if isinstance(var.type, RandomGeneratorType)
        ]
        n_shared_rngs = len(shared_rngs)
        # `child_rng` (rather than `rng`) avoids shadowing the method argument
        swap = {
            old_shared_rng: shared(child_rng, borrow=True)
            for old_shared_rng, child_rng in zip(shared_rngs, rng.spawn(n_shared_rngs), strict=True)
        }
        self.posterior_fn = self.posterior_fn.copy(swap=swap)

    def step(self, point: PointType) -> tuple[PointType, list]:
        """Return a copy of `point` with this variable replaced by a posterior draw, and no stats."""
        new_point = point.copy()
        new_point[self.rv_name] = self.posterior_fn(**point)
        return new_point, []

    @staticmethod
    def competence(var, has_grad):
        """ConjugateRVSampler is ideal for ConjugateRV variables and incompatible with all others."""
        if isinstance(var.owner.op, ConjugateRV):
            return Competence.IDEAL

        return Competence.INCOMPATIBLE

pymc_experimental/utils/ofg.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from pytensor.graph.basic import Variable
2+
from pytensor.graph.replace import clone_replace
3+
from pytensor.compile.builders import OpFromGraph
4+
from typing import Sequence
5+
6+
7+
def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
    """Unwrap an OpFromGraph by returning its inner outputs expressed on `inputs`.

    `OpFromGraph` hides a graph behind a single Op. This function clones the
    inner outputs of `op`, replacing each inner input by the corresponding
    entry of `inputs`.
    """
    replacements = tuple(zip(op.inner_inputs, inputs))
    return clone_replace(op.inner_outputs, replace=replacements)

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ addopts = [
88
]
99

1010
filterwarnings =[
11-
"error",
11+
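# "error",  # NOTE(review): warnings-as-errors is disabled by this change; confirm it is intentional and not a debugging leftover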
1212
# Raised by arviz when the model_builder class adds non-standard group names to InferenceData
1313
"ignore::UserWarning:arviz.data.inference_data",
1414

0 commit comments

Comments
 (0)