Skip to content

Commit fa8dfff

Browse files
committed
Add Normal summary stats optimization
1 parent e8f490c commit fa8dfff

File tree

6 files changed

+170
-1
lines changed

6 files changed

+170
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# ruff: noqa: F401
2+
# Add rewrites to the optimization DBs
3+
import pymc_experimental.sampling.optimizations.summary_stats
4+
5+
from pymc_experimental.sampling.optimizations.optimize import (
6+
optimize_model_for_mcmc_sampling,
7+
posterior_optimization_db,
8+
)
9+
10+
__all__ = [
11+
"posterior_optimization_db",
12+
"optimize_model_for_mcmc_sampling",
13+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import pytensor.tensor as pt
2+
3+
from pymc.distributions import Gamma, Normal
4+
from pymc.model.fgraph import ModelObservedRV, model_observed_rv
5+
from pytensor.graph.fg import FunctionGraph
6+
from pytensor.graph.rewriting.basic import node_rewriter
7+
8+
from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db
9+
10+
11+
@node_rewriter(tracks=[ModelObservedRV])
def summary_stats_normal(fgraph: FunctionGraph, node):
    """Replace an observed Normal likelihood by likelihoods on its summary statistics.

    Applies the equivalence (up to a normalizing constant) described in:

    https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics

    When the rewrite applies, the observed ``Normal(mu, sigma)`` variable is
    removed from ``fgraph`` and two new observed variables are added as outputs:

    * ``"<name>_mean"``: the mean of the data, modelled as
      ``Normal(mu, sigma / sqrt(N))``
    * ``"<name>_var"``: the variance of the data (``ddof=1``), modelled as
      ``Gamma(alpha=(N - 1) / 2, beta=(N - 1) / (2 * sigma**2))``

    where ``N`` is ``data.size``.

    Parameters
    ----------
    fgraph : FunctionGraph
        The graph being rewritten. It is mutated in place when the rewrite applies.
    node
        The ``ModelObservedRV`` node under consideration; its inputs are the
        random variable and the observed data.

    Returns
    -------
    list or None
        ``None`` when the rewrite does not apply. Otherwise a one-element list
        holding a copy of the (already removed) node output, returned only so
        that the rewrite is counted in the profile.
    """
    [observed_rv] = node.outputs
    [rv, data] = node.inputs

    # Only Normal likelihoods are handled (a variable without an owner cannot be one)
    if rv.owner is None or not isinstance(rv.owner.op, Normal):
        return None

    # Check the normal RV is not just a scalar
    if all(rv.type.broadcastable):
        return None

    # Check that the observed RV is not used anywhere else (like a Potential or Deterministic)
    # There should be only one use: as an "output"
    if len(fgraph.clients[observed_rv]) > 1:
        return None

    mu, sigma = rv.owner.op.dist_params(rv.owner)

    # Check that mu and sigma are BOTH scalar: the summary statistics below are
    # only sufficient when every observation shares the same mu and sigma,
    # so bail out as soon as either parameter is not scalar.
    if not all(mu.type.broadcastable) or not all(sigma.type.broadcastable):
        return None

    # Check that mu and sigma are not used anywhere else
    # Note: This is too restrictive, it's fine if they're used in Deterministics!
    # There should only be two uses: as an "output" and as the param of the `rv`
    if len(fgraph.clients[mu]) > 2 or len(fgraph.clients[sigma]) > 2:
        return None

    # Remove expand_dims
    mu = mu.squeeze()
    sigma = sigma.squeeze()

    # Apply the rewrite
    mean_data = pt.mean(data)
    mean_data.name = None
    var_data = pt.var(data, ddof=1)
    var_data.name = None
    N = data.size
    sqrt_N = pt.sqrt(N)
    nm1_over2 = (N - 1) / 2

    observed_mean = model_observed_rv(
        Normal.dist(mu=mu, sigma=sigma / sqrt_N),
        mean_data,
    )
    observed_mean.name = f"{rv.name}_mean"

    observed_var = model_observed_rv(
        Gamma.dist(alpha=nm1_over2, beta=nm1_over2 / (sigma**2)),
        var_data,
    )
    observed_var.name = f"{rv.name}_var"

    fgraph.add_output(observed_mean, import_missing=True)
    fgraph.add_output(observed_var, import_missing=True)
    fgraph.remove_node(node)
    # Just so it shows in the profile for verbose=True,
    # It won't do anything because node is not in the fgraph anymore
    return [node.out.copy()]
75+
76+
77+
# Make the rewrite discoverable: it is registered under its own function name
# and tagged "default" and "summary_stats".
posterior_optimization_db.register(
    summary_stats_normal.__name__,
    summary_stats_normal,
    "default",
    "summary_stats",
)

tests/sampling/__init__.py

Whitespace-only changes.

tests/sampling/mcmc/test_mcmc.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import numpy as np
12
import pytest
23

3-
from pymc.distributions import Beta, Binomial, InverseGamma
4+
from pymc.distributions import Beta, Binomial, HalfNormal, InverseGamma, Normal
45
from pymc.model.core import Model
6+
from pymc.sampling.mcmc import sample
57
from pymc.step_methods import Slice
68

79
from pymc_experimental import opt_sample
@@ -18,3 +20,35 @@ def test_custom_step_raises():
1820
ValueError, match="The `step` argument is not supported in `opt_sample`"
1921
):
2022
opt_sample(step=Slice([a, b]))
23+
24+
25+
def test_sample_opt_summary_stats(capsys):
    """Compare `opt_sample` against plain `sample` on a scalar-parameter Normal model.

    Checks that the summary-stats rewrite is reported as applied, that the
    posterior has the requested chain/draw sizes, that the posterior means of
    ``mu`` and ``sigma`` agree within 1%, and that the optimized run reports a
    shorter sampling time.
    """
    # Identical settings for both runs so that only the optimization differs
    run_settings = {
        "chains": 1,
        "tune": 500,
        "draws": 500,
        "compute_convergence_checks": False,
        "progressbar": False,
    }
    observations = np.random.default_rng(3).normal(loc=1, scale=0.5, size=(1000,))

    with Model():
        mu = Normal("mu")
        sigma = HalfNormal("sigma")
        Normal("y", mu=mu, sigma=sigma, observed=observations)

        idata = sample(**run_settings)
        # TODO: Make extract_data more robust to avoid this warning/error
        # Or alternatively extract data on the original model, not the optimized one
        with pytest.warns(UserWarning, match="Could not extract data from symbolic observation"):
            opt_idata = opt_sample(verbose=True, **run_settings)

    stdout_text = capsys.readouterr().out
    assert "Applied optimization: summary_stats_normal 1x" in stdout_text

    opt_posterior = opt_idata.posterior
    assert opt_posterior.sizes["chain"] == 1
    assert opt_posterior.sizes["draw"] == 500

    # Posterior means must agree between the regular and the optimized run
    for var_name in ("mu", "sigma"):
        np.testing.assert_allclose(
            idata.posterior[var_name].mean(),
            opt_posterior[var_name].mean(),
            rtol=1e-2,
        )

    assert idata.sample_stats.sampling_time > opt_idata.sample_stats.sampling_time

tests/sampling/optimizations/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import numpy as np
2+
3+
from pymc.distributions import HalfNormal, Normal
4+
from pymc.model.core import Model
5+
6+
from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling
7+
8+
9+
def test_summary_stats_normal():
    """Check that the summary-stats rewrite fires and leaves the posterior unchanged.

    The optimized model must keep the two free variables, end up with two
    observed variables, have a logp that differs from the original only by a
    constant offset, and have the same logp gradient.
    """
    y_data = np.random.default_rng(3).normal(loc=1, scale=0.5, size=(1000,))

    with Model() as m:
        mu = Normal("mu")
        sigma = HalfNormal("sigma")
        Normal("y", mu=mu, sigma=sigma, observed=y_data)

    assert (len(m.free_RVs), len(m.observed_RVs)) == (2, 1)

    new_m, rewrite_counters = optimize_model_for_mcmc_sampling(m)
    assert any(
        r.name == "summary_stats_normal" for rc in rewrite_counters for r in rc
    )

    assert (len(new_m.free_RVs), len(new_m.observed_RVs)) == (2, 2)

    # Confirm equivalent (up to an additive normalization constant)
    m_logp = m.compile_logp()
    new_m_logp = new_m.compile_logp()

    def logp_gap(point):
        """Return original logp minus optimized logp, evaluated at `point`."""
        return m_logp(point) - new_m_logp(point)

    point = m.initial_point()
    gap_at_start = logp_gap(point)

    # Move to a different point: the gap must not depend on the parameters
    point["mu"] += 0.5
    point["sigma_log__"] += 1.5
    gap_after_shift = logp_gap(point)

    np.testing.assert_allclose(gap_at_start, gap_after_shift)

    # dlogp should be the same
    m_dlogp = m.compile_dlogp()
    new_m_dlogp = new_m.compile_dlogp()
    np.testing.assert_allclose(m_dlogp(point), new_m_dlogp(point))

0 commit comments

Comments
 (0)