Skip to content

Commit 493855f

Browse files
committed
Add Normal summary stats optimization
1 parent 3c19a5e commit 493855f

File tree

9 files changed

+166
-1
lines changed

9 files changed

+166
-1
lines changed
+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# ruff: noqa: F401
2+
# Add rewrites to the optimization DBs
3+
import pymc_experimental.sampling.optimizations.summary_stats

pymc_experimental/sampling/mcmc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def opt_sample(
5353
fgraph, _ = fgraph_from_model(model)
5454

5555
if rewriter is None:
56-
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=[]))
56+
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["summary_stats"]))
5757
_, _, rewrite_counters, *_ = rewriter.rewrite(fgraph)
5858

5959
if verbose:

pymc_experimental/sampling/optimizations/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import pytensor.tensor as pt
2+
3+
from pymc.distributions import Gamma, Normal
4+
from pymc.model.fgraph import ModelObservedRV, model_observed_rv
5+
from pytensor.graph.fg import FunctionGraph
6+
from pytensor.graph.rewriting.basic import node_rewriter
7+
8+
from pymc_experimental.sampling.mcmc import posterior_optimization_db
9+
10+
11+
@node_rewriter(tracks=[ModelObservedRV])
def summary_stats_normal(fgraph: FunctionGraph, node):
    """Replace an observed Normal with scalar parameters by its summary statistics.

    Applies the equivalence (up to a normalizing constant) described in:

    https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics

    The observed ``Normal(mu, sigma)`` over ``data`` is replaced by two
    observed variables added as new outputs of ``fgraph``:

    * ``Normal(mu, sigma / sqrt(N))`` observed at ``mean(data)``
    * ``Gamma((N - 1) / 2, (N - 1) / (2 * sigma**2))`` observed at
      ``var(data, ddof=1)``

    where ``N = data.size``.

    Parameters
    ----------
    fgraph
        The model function graph being rewritten. It is mutated in place:
        two outputs are added and ``node`` is removed.
    node
        A ``ModelObservedRV`` node whose inputs are ``[rv, data]``.

    Returns
    -------
    ``None`` when the rewrite does not apply; otherwise a one-element list
    holding a copy of the removed node's output (see the note at the end of
    the function body).
    """
    [observed_rv] = node.outputs
    [rv, data] = node.inputs

    # Only observed variables produced by a Normal op are handled
    if rv.owner is None or not isinstance(rv.owner.op, Normal):
        return None

    # Check the normal RV is not just a scalar
    if all(rv.type.broadcastable):
        return None

    # Check that the observed RV is not used anywhere else (like a Potential or Deterministic)
    # There should be only one use: as an "output"
    if len(fgraph.clients[observed_rv]) > 1:
        return None

    mu, sigma = rv.owner.op.dist_params(rv.owner)

    # Both mu and sigma must be scalar (all dims broadcastable). The rewrite
    # collapses the data into a single mean and a single variance, which is
    # only valid when every observation shares the same mu and sigma, so bail
    # out if *either* parameter is not scalar.
    if not (all(mu.type.broadcastable) and all(sigma.type.broadcastable)):
        return None

    # Check that mu and sigma are not used anywhere else
    # Note: This is too restrictive, it's fine if they're used in Deterministics!
    # There should only be two uses: as an "output" and as the param of the `rv`
    if len(fgraph.clients[mu]) > 2 or len(fgraph.clients[sigma]) > 2:
        return None

    # Remove expand_dims
    mu = mu.squeeze()
    sigma = sigma.squeeze()

    # Apply the rewrite
    mean_data = pt.mean(data)
    mean_data.name = None
    var_data = pt.var(data, ddof=1)
    var_data.name = None
    N = data.size
    sqrt_N = pt.sqrt(N)
    nm1_over2 = (N - 1) / 2

    # The sample mean of N draws has standard deviation sigma / sqrt(N)
    observed_mean = model_observed_rv(
        Normal.dist(mu=mu, sigma=sigma / sqrt_N),
        mean_data,
    )
    observed_mean.name = f"{rv.name}_mean"

    # The unbiased sample variance is Gamma((N-1)/2, rate=(N-1)/(2*sigma^2))
    observed_var = model_observed_rv(
        Gamma.dist(alpha=nm1_over2, beta=nm1_over2 / (sigma**2)),
        var_data,
    )
    observed_var.name = f"{rv.name}_var"

    fgraph.add_output(observed_mean, import_missing=True)
    fgraph.add_output(observed_var, import_missing=True)
    fgraph.remove_node(node)
    # Just so it shows in the profile for verbose=True,
    # It won't do anything because node is not in the fgraph anymore
    return [node.out.copy()]
75+
76+
77+
# Register the rewrite in the posterior optimization database under its own
# function name, tagged "summary_stats" so that queries including that tag
# pick it up.
posterior_optimization_db.register(
    summary_stats_normal.__name__,
    summary_stats_normal,
    "summary_stats",
)

tests/sampling/__init__.py

Whitespace-only changes.

tests/sampling/mcmc/__init__.py

Whitespace-only changes.

tests/sampling/mcmc/test_mcmc.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
3+
from pymc.distributions import HalfNormal, Normal
4+
from pymc.model.core import Model
5+
from pymc.sampling.mcmc import sample
6+
7+
from pymc_experimental import opt_sample
8+
9+
10+
def test_sample_opt_summary_stats(capsys):
    """Check that `opt_sample` applies the Normal summary-stats rewrite.

    Samples the same model with plain `sample` and with `opt_sample`, then
    verifies that the rewrite was reported on stdout, that the posterior has
    the requested shape, that posterior means agree, and that the optimized
    run reported a shorter sampling time.
    """
    rng = np.random.default_rng(3)
    y_data = rng.normal(loc=1, scale=0.5, size=(1000,))

    with Model():
        mu = Normal("mu")
        sigma = HalfNormal("sigma")
        Normal("y", mu=mu, sigma=sigma, observed=y_data)

        sample_kwargs = dict(
            chains=1,
            tune=500,
            draws=500,
            compute_convergence_checks=False,
            progressbar=False,
            # Seed both runs identically so the posterior-mean comparison
            # below does not depend on run-to-run Monte Carlo noise.
            random_seed=3,
        )
        idata = sample(**sample_kwargs)
        opt_idata = opt_sample(**sample_kwargs, verbose=True)

    captured_out = capsys.readouterr().out
    assert "Applied optimization: summary_stats_normal 1x" in captured_out

    assert opt_idata.posterior.sizes["chain"] == 1
    assert opt_idata.posterior.sizes["draw"] == 500
    np.testing.assert_allclose(
        idata.posterior["mu"].mean(), opt_idata.posterior["mu"].mean(), rtol=1e-3
    )
    np.testing.assert_allclose(
        idata.posterior["sigma"].mean(), opt_idata.posterior["sigma"].mean(), rtol=1e-2
    )
    # The optimized model evaluates two scalar likelihood terms instead of
    # one term per observation, so it is expected to sample faster.
    assert idata.sample_stats.sampling_time > opt_idata.sample_stats.sampling_time

tests/sampling/optimizations/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import numpy as np
2+
3+
from pymc.distributions import HalfNormal, Normal
4+
from pymc.model.core import Model
5+
from pymc.model.fgraph import fgraph_from_model, model_from_fgraph
6+
from pytensor.graph.rewriting.basic import out2in
7+
8+
from pymc_experimental.sampling.optimizations.summary_stats import summary_stats_normal
9+
10+
11+
def test_summary_stats_normal():
    """Check that the summary-stats rewrite preserves logp up to a constant.

    Builds a model with a Normal likelihood with scalar parameters, applies
    `summary_stats_normal` to its function graph, and compares the original
    and rewritten models: the logp difference must be the same at two
    different points, and the gradients must match.
    """
    data_rng = np.random.default_rng(3)
    observations = data_rng.normal(loc=1, scale=0.5, size=(1000,))

    with Model() as m:
        mu = Normal("mu")
        sigma = HalfNormal("sigma")
        y = Normal("y", mu=mu, sigma=sigma, observed=observations)

    assert len(m.free_RVs) == 2
    assert len(m.observed_RVs) == 1

    # Round-trip through the function graph, applying the rewrite in between
    fgraph, _ = fgraph_from_model(m)
    out2in(summary_stats_normal).apply(fgraph)
    new_m = model_from_fgraph(fgraph)

    # The single observed Normal is replaced by an observed mean and variance
    assert len(new_m.free_RVs) == 2
    assert len(new_m.observed_RVs) == 2

    # Confirm equivalent (up to an additive normalization constant)
    original_logp = m.compile_logp()
    rewritten_logp = new_m.compile_logp()

    point = m.initial_point()
    diff_at_initial = original_logp(point) - rewritten_logp(point)

    # Move to a different point; the offset between the models must not change
    point["mu"] += 0.5
    point["sigma_log__"] += 1.5
    diff_at_shifted = original_logp(point) - rewritten_logp(point)

    np.testing.assert_allclose(diff_at_initial, diff_at_shifted)

    # dlogp should be the same
    original_dlogp = m.compile_dlogp()
    rewritten_dlogp = new_m.compile_dlogp()
    np.testing.assert_allclose(original_dlogp(point), rewritten_dlogp(point))

0 commit comments

Comments
 (0)