Skip to content

Commit e051965

Browse files
committed
Add Normal summary stats optimization
1 parent 4e143ce commit e051965

File tree

9 files changed

+158
-1
lines changed

9 files changed

+158
-1
lines changed
+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Add rewrites to the optimization DBs
2+
import pymc_experimental.sampling.optimizations.summary_stats

pymc_experimental/sampling/mcmc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def opt_sample(
5353
fgraph, _ = fgraph_from_model(model)
5454

5555
if rewriter is None:
56-
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=[]))
56+
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["summary_stats"]))
5757
_, _, rewrite_counters, *_ = rewriter.rewrite(fgraph)
5858

5959
if verbose:

pymc_experimental/sampling/optimizations/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from pytensor.graph.fg import FunctionGraph
2+
from pytensor.graph.rewriting.basic import node_rewriter
3+
from pymc.model.fgraph import ModelObservedRV, model_observed_rv
4+
from pymc_experimental.sampling.mcmc import posterior_optimization_db
5+
import pytensor.tensor as pt
6+
from pymc.distributions import Normal, Gamma
7+
8+
9+
@node_rewriter(tracks=[ModelObservedRV])
def summary_stats_normal(fgraph: FunctionGraph, node):
    """Replace an observed Normal by observed summary statistics of its data.

    This applies the equivalence (up to a normalizing constant) described in:

    https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics

    For an observed ``Normal(mu, sigma)`` whose ``mu`` and ``sigma`` are both
    scalar, the observed variable is replaced by two new observed variables::

        mean(data)         ~ Normal(mu, sigma / sqrt(N))
        var(data, ddof=1)  ~ Gamma(alpha=(N - 1) / 2, beta=((N - 1) / 2) / sigma**2)

    where ``N = data.size``.

    Parameters
    ----------
    fgraph
        Function graph being rewritten. It is mutated in place: the two new
        observed variables are added as outputs and ``node`` is removed.
    node
        The ``ModelObservedRV`` node under consideration. Its inputs are the
        random variable and the observed data.

    Returns
    -------
    None or list
        ``None`` when the rewrite does not apply. Otherwise a one-element list
        holding a copy of the node's output; this is returned only so the
        rewrite is reported as applied.
    """
    [observed_rv] = node.outputs
    [rv, data] = node.inputs

    if rv.owner is None or not isinstance(rv.owner.op, Normal):
        return None

    # Check the normal RV is not just a scalar
    if all(rv.type.broadcastable):
        return None

    # Check that the observed RV is not used anywhere else (like a Potential or Deterministic)
    # There should be only one use: as an "output"
    if len(fgraph.clients[observed_rv]) > 1:
        return None

    mu, sigma = rv.owner.op.dist_params(rv.owner)

    # Both mu and sigma must be scalar: the rewrite below squeezes them and
    # summarizes all of `data` with a single mean and a single variance.
    # Bail out if *either* parameter is not scalar.
    if not (all(mu.type.broadcastable) and all(sigma.type.broadcastable)):
        return None

    # Check that mu and sigma are not used anywhere else
    # Note: This is too restrictive, it's fine if they're used in Deterministics!
    # There should only be two uses: as an "output" and as the param of the `rv`
    if len(fgraph.clients[mu]) > 2 or len(fgraph.clients[sigma]) > 2:
        return None

    # Remove expand_dims
    mu = mu.squeeze()
    sigma = sigma.squeeze()

    # Apply the rewrite
    mean_data = pt.mean(data)
    mean_data.name = None
    var_data = pt.var(data, ddof=1)
    var_data.name = None
    N = data.size
    sqrt_N = pt.sqrt(N)
    nm1_over2 = (N - 1) / 2

    observed_mean = model_observed_rv(
        Normal.dist(mu=mu, sigma=sigma / sqrt_N),
        mean_data,
    )
    observed_mean.name = f"{rv.name}_mean"

    observed_var = model_observed_rv(
        Gamma.dist(alpha=nm1_over2, beta=nm1_over2 / (sigma ** 2)),
        var_data,
    )
    observed_var.name = f"{rv.name}_var"

    fgraph.add_output(observed_mean, import_missing=True)
    fgraph.add_output(observed_var, import_missing=True)
    fgraph.remove_node(node)
    # Just so it shows in the profile for verbose=True,
    # It won't do anything because node is not in the fgraph anymore
    return [node.out.copy()]
73+
74+
75+
# Register the rewrite in the posterior optimization database under its own
# function name, tagged "summary_stats".
posterior_optimization_db.register(
    summary_stats_normal.__name__, summary_stats_normal, "summary_stats"
)

tests/sampling/__init__.py

Whitespace-only changes.

tests/sampling/mcmc/__init__.py

Whitespace-only changes.

tests/sampling/mcmc/test_mcmc.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import numpy as np
2+
from pymc.model.core import Model
3+
from pymc.distributions import Normal, HalfNormal
4+
from pymc.sampling.mcmc import sample
5+
6+
from pymc_experimental import opt_sample
7+
8+
9+
def test_sample_opt_summary_stats(capsys):
    """Check that ``opt_sample`` applies ``summary_stats_normal`` and agrees with ``sample``.

    Samples the same model once with ``sample`` and once with ``opt_sample``,
    then checks the verbose report, the posterior dimensions, the posterior
    means of ``mu`` and ``sigma``, and the recorded sampling times.
    """
    rng = np.random.default_rng(3)
    y_data = rng.normal(loc=1, scale=0.5, size=(1000,))

    with Model():
        mu = Normal("mu")
        sigma = HalfNormal("sigma")
        Normal("y", mu=mu, sigma=sigma, observed=y_data)

        sample_kwargs = {
            "chains": 1,
            "tune": 500,
            "draws": 500,
            "compute_convergence_checks": False,
            "progressbar": False,
        }
        idata = sample(**sample_kwargs)
        opt_idata = opt_sample(**sample_kwargs, verbose=True)

    # verbose=True is expected to print which optimizations were applied
    printed = capsys.readouterr().out
    assert "Applied optimization: summary_stats_normal 1x" in printed

    opt_posterior = opt_idata.posterior
    assert opt_posterior.sizes["chain"] == 1
    assert opt_posterior.sizes["draw"] == 500

    # NOTE(review): no random seed is passed to either sampler, so these
    # tolerance checks compare two independent runs — confirm this is stable.
    for var_name, rtol in (("mu", 1e-3), ("sigma", 1e-2)):
        np.testing.assert_allclose(
            idata.posterior[var_name].mean(),
            opt_posterior[var_name].mean(),
            rtol=rtol,
        )

    # NOTE(review): wall-clock comparison; may be sensitive to machine load.
    assert idata.sample_stats.sampling_time > opt_idata.sample_stats.sampling_time

tests/sampling/optimizations/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import numpy as np
2+
from pymc.model.core import Model
3+
from pymc.distributions import Normal, HalfNormal
4+
from pymc.model.fgraph import fgraph_from_model, model_from_fgraph
5+
from pytensor.graph.rewriting.basic import out2in
6+
7+
from pymc_experimental.sampling.optimizations.summary_stats import summary_stats_normal
8+
9+
10+
def test_summary_stats_normal():
    """Check that ``summary_stats_normal`` yields an equivalent model.

    The rewritten model must keep the two free variables, have two observed
    variables instead of one, have a logp that differs from the original only
    by a constant, and have the same logp gradient.
    """
    rng = np.random.default_rng(3)
    y_data = rng.normal(loc=1, scale=0.5, size=(1000,))

    with Model() as m:
        mu = Normal("mu")
        sigma = HalfNormal("sigma")
        Normal("y", mu=mu, sigma=sigma, observed=y_data)

    assert len(m.free_RVs) == 2
    assert len(m.observed_RVs) == 1

    # Apply the rewrite on the model's function graph and rebuild a model from it
    fgraph, _ = fgraph_from_model(m)
    out2in(summary_stats_normal).apply(fgraph)
    new_m = model_from_fgraph(fgraph)

    assert len(new_m.free_RVs) == 2
    assert len(new_m.observed_RVs) == 2

    # Confirm equivalent (up to an additive normalization constant):
    # the logp difference between the two models must not depend on the point
    m_logp = m.compile_logp()
    new_m_logp = new_m.compile_logp()

    def logp_gap(point):
        return m_logp(point) - new_m_logp(point)

    ip = m.initial_point()
    gap_at_initial_point = logp_gap(ip)

    ip["mu"] += 0.5
    ip["sigma_log__"] += 1.5
    gap_at_shifted_point = logp_gap(ip)

    np.testing.assert_allclose(gap_at_initial_point, gap_at_shifted_point)

    # dlogp should be the same
    m_dlogp = m.compile_dlogp()
    new_m_dlogp = new_m.compile_dlogp()
    np.testing.assert_allclose(m_dlogp(ip), new_m_dlogp(ip))

0 commit comments

Comments
 (0)