Skip to content

Commit c417476

Browse files
committed
Add metropolis sampling state
1 parent 0a3f5e4 commit c417476

File tree

3 files changed

+166
-12
lines changed

3 files changed

+166
-12
lines changed

pymc/step_methods/metropolis.py

Lines changed: 100 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from collections.abc import Callable
15+
from dataclasses import field
16+
from typing import Any
1517

1618
import numpy as np
1719
import numpy.random as nr
@@ -40,7 +42,8 @@
4042
StatsType,
4143
metrop_select,
4244
)
43-
from pymc.step_methods.compound import Competence
45+
from pymc.step_methods.compound import Competence, StepMethodState
46+
from pymc.step_methods.state import dataclass_state
4447

4548
__all__ = [
4649
"Metropolis",
@@ -111,18 +114,40 @@ def __call__(self, num_draws=None, rng: np.random.Generator | None = None):
111114
return np.dot(self.chol, b)
112115

113116

117+
@dataclass_state
118+
class MetropolisState(StepMethodState):
119+
scaling: np.ndarray
120+
tune: bool
121+
steps_until_tune: float
122+
tune_interval: float
123+
accepted_sum: np.ndarray
124+
accept_rate_iter: np.ndarray
125+
accepted_iter: np.ndarray
126+
enum_dims: np.ndarray
127+
128+
discrete: np.ndarray = field(metadata={"frozen": True})
129+
any_discrete: bool = field(metadata={"frozen": True})
130+
all_discrete: bool = field(metadata={"frozen": True})
131+
elemwise_update: bool = field(metadata={"frozen": True})
132+
_untuned_settings: dict[str, np.ndarray | float] = field(metadata={"frozen": True})
133+
mode: Any = field(metadata={"frozen": True})
134+
135+
114136
class Metropolis(ArrayStepShared):
115137
"""Metropolis-Hastings sampling step"""
116138

117139
name = "metropolis"
118140

141+
default_blocked = False
119142
stats_dtypes_shapes = {
120143
"accept": (np.float64, []),
121144
"accepted": (np.float64, []),
122145
"tune": (bool, []),
123146
"scaling": (np.float64, []),
124147
}
125148

149+
_state_class = MetropolisState
150+
126151
def __init__(
127152
self,
128153
vars=None,
@@ -346,6 +371,15 @@ def tune(scale, acc_rate):
346371
)
347372

348373

374+
@dataclass_state
375+
class BinaryMetropolisState(StepMethodState):
376+
tune: bool
377+
accepted: int
378+
scaling: float
379+
tune_interval: int
380+
steps_until_tune: int
381+
382+
349383
class BinaryMetropolis(ArrayStep):
350384
"""Metropolis-Hastings optimized for binary variables
351385
@@ -375,7 +409,9 @@ class BinaryMetropolis(ArrayStep):
375409
"p_jump": (np.float64, []),
376410
}
377411

378-
def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
412+
_state_class = BinaryMetropolisState
413+
414+
def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None, rng=None):
379415
model = pm.modelcontext(model)
380416

381417
self.scaling = scaling
@@ -389,7 +425,7 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
389425
if not all([v.dtype in pm.discrete_types for v in vars]):
390426
raise ValueError("All variables must be Bernoulli for BinaryMetropolis")
391427

392-
super().__init__(vars, [model.compile_logp()])
428+
super().__init__(vars, [model.compile_logp()], rng=rng)
393429

394430
def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
395431
logp = args[0]
@@ -445,6 +481,14 @@ def competence(var):
445481
return Competence.INCOMPATIBLE
446482

447483

484+
@dataclass_state
485+
class BinaryGibbsMetropolisState(StepMethodState):
486+
tune: bool
487+
transit_p: int
488+
shuffle_dims: bool
489+
order: list
490+
491+
448492
class BinaryGibbsMetropolis(ArrayStep):
449493
"""A Metropolis-within-Gibbs step method optimized for binary variables
450494
@@ -472,7 +516,9 @@ class BinaryGibbsMetropolis(ArrayStep):
472516
"tune": (bool, []),
473517
}
474518

475-
def __init__(self, vars, order="random", transit_p=0.8, model=None):
519+
_state_class = BinaryGibbsMetropolisState
520+
521+
def __init__(self, vars, order="random", transit_p=0.8, model=None, rng=None):
476522
model = pm.modelcontext(model)
477523

478524
# Doesn't actually tune, but it's required to emit a sampler stat
@@ -498,7 +544,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
498544
if not all([v.dtype in pm.discrete_types for v in vars]):
499545
raise ValueError("All variables must be binary for BinaryGibbsMetropolis")
500546

501-
super().__init__(vars, [model.compile_logp()])
547+
super().__init__(vars, [model.compile_logp()], rng=rng)
502548

503549
def reset_tuning(self):
504550
# There are no tuning parameters in this step method.
@@ -557,6 +603,13 @@ def competence(var):
557603
return Competence.INCOMPATIBLE
558604

559605

606+
@dataclass_state
607+
class CategoricalGibbsMetropolisState(StepMethodState):
608+
shuffle_dims: bool
609+
dimcats: list[tuple]
610+
tune: bool
611+
612+
560613
class CategoricalGibbsMetropolis(ArrayStep):
561614
"""A Metropolis-within-Gibbs step method optimized for categorical variables.
562615
@@ -573,6 +626,8 @@ class CategoricalGibbsMetropolis(ArrayStep):
573626
"tune": (bool, []),
574627
}
575628

629+
_state_class = CategoricalGibbsMetropolisState
630+
576631
def __init__(self, vars, proposal="uniform", order="random", model=None, rng=None):
577632
model = pm.modelcontext(model)
578633

@@ -728,6 +783,18 @@ def competence(var):
728783
return Competence.INCOMPATIBLE
729784

730785

786+
@dataclass_state
787+
class DEMetropolisState(StepMethodState):
788+
scaling: np.ndarray
789+
lamb: float
790+
tune: str | None
791+
tune_interval: int
792+
steps_until_tune: int
793+
accepted: int
794+
795+
mode: Any = field(metadata={"frozen": True})
796+
797+
731798
class DEMetropolis(PopulationArrayStepShared):
732799
"""
733800
Differential Evolution Metropolis sampling step.
@@ -778,6 +845,8 @@ class DEMetropolis(PopulationArrayStepShared):
778845
"lambda": (np.float64, []),
779846
}
780847

848+
_state_class = DEMetropolisState
849+
781850
def __init__(
782851
self,
783852
vars=None,
@@ -789,6 +858,7 @@ def __init__(
789858
tune_interval=100,
790859
model=None,
791860
mode=None,
861+
rng=None,
792862
**kwargs,
793863
):
794864
model = pm.modelcontext(model)
@@ -824,7 +894,7 @@ def __init__(
824894

825895
shared = pm.make_shared_replacements(initial_values, vars, model)
826896
self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
827-
super().__init__(vars, shared)
897+
super().__init__(vars, shared, rng=rng)
828898

829899
def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
830900
point_map_info = q0.point_map_info
@@ -843,9 +913,11 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
843913

844914
# differential evolution proposal
845915
# select two other chains
846-
ir1, ir2 = np.random.choice(self.other_chains, 2, replace=False)
847-
r1 = DictToArrayBijection.map(self.population[ir1])
848-
r2 = DictToArrayBijection.map(self.population[ir2])
916+
if self.other_chains is None: # pragma: no cover
917+
raise RuntimeError("Population sampler has not been linked to the other chains")
918+
ir1, ir2 = self.rng.choice(self.other_chains, 2, replace=False)
919+
r1 = DictToArrayBijection.map(self.population[ir1]) # type: ignore
920+
r2 = DictToArrayBijection.map(self.population[ir2]) # type: ignore
849921
# propose a jump
850922
q = floatX(q0d + self.lamb * (r1.data - r2.data) + epsilon)
851923

@@ -872,6 +944,21 @@ def competence(var, has_grad):
872944
return Competence.COMPATIBLE
873945

874946

947+
@dataclass_state
948+
class DEMetropolisZState(StepMethodState):
949+
scaling: np.ndarray
950+
lamb: float
951+
tune: bool
952+
tune_target: str | None
953+
tune_interval: int
954+
steps_until_tune: int
955+
accepted: int
956+
_history: list
957+
958+
_untuned_settings: dict[str, np.ndarray | float] = field(metadata={"frozen": True})
959+
mode: Any = field(metadata={"frozen": True})
960+
961+
875962
class DEMetropolisZ(ArrayStepShared):
876963
"""
877964
Adaptive Differential Evolution Metropolis sampling step that uses the past to inform jumps.
@@ -925,6 +1012,8 @@ class DEMetropolisZ(ArrayStepShared):
9251012
"lambda": (np.float64, []),
9261013
}
9271014

1015+
_state_class = DEMetropolisZState
1016+
9281017
def __init__(
9291018
self,
9301019
vars=None,
@@ -937,6 +1026,7 @@ def __init__(
9371026
tune_drop_fraction: float = 0.9,
9381027
model=None,
9391028
mode=None,
1029+
rng=None,
9401030
**kwargs,
9411031
):
9421032
model = pm.modelcontext(model)
@@ -984,7 +1074,7 @@ def __init__(
9841074

9851075
shared = pm.make_shared_replacements(initial_values, vars, model)
9861076
self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
987-
super().__init__(vars, shared)
1077+
super().__init__(vars, shared, rng=rng)
9881078

9891079
def reset_tuning(self):
9901080
"""Resets the tuned sampler parameters and history to their initial values."""

tests/models.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,14 @@ def simple_normal(bounded_prior=False):
186186
pm.Normal("X_obs", mu=mu_i, sigma=sigma, observed=x0)
187187

188188
return model.initial_point(), model, None
189+
190+
191+
def simple_binary():
192+
p1 = 0.5
193+
p2 = 0.5
194+
195+
with pm.Model() as model:
196+
pm.Bernoulli("d1", p=p1)
197+
pm.Bernoulli("d2", p=p2)
198+
199+
return model.initial_point(), model, (p1, p2)

tests/step_methods/test_metropolis.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import warnings
1616

17+
from copy import deepcopy
18+
1719
import arviz as az
1820
import numpy as np
1921
import numpy.testing as npt
@@ -24,17 +26,25 @@
2426

2527
from pymc.step_methods.metropolis import (
2628
BinaryGibbsMetropolis,
29+
BinaryMetropolis,
2730
CategoricalGibbsMetropolis,
2831
DEMetropolis,
2932
DEMetropolisZ,
3033
Metropolis,
3134
MultivariateNormalProposal,
3235
NormalProposal,
3336
)
37+
from pymc.step_methods.state import equal_dataclass_values
3438
from pymc.testing import fast_unstable_sampling_mode
3539
from tests import sampler_fixtures as sf
36-
from tests.helpers import RVsAssignmentStepsTester, StepMethodTester
37-
from tests.models import mv_simple, mv_simple_discrete, simple_categorical
40+
from tests.helpers import RVsAssignmentStepsTester, StepMethodTester, equal_sampling_states
41+
from tests.models import (
42+
mv_simple,
43+
mv_simple_discrete,
44+
simple_binary,
45+
simple_categorical,
46+
simple_model,
47+
)
3848

3949
SEED = sum(ord(c) for c in "test_metropolis")
4050

@@ -47,6 +57,7 @@ class TestMetropolisUniform(sf.MetropolisFixture, sf.UniformFixture):
4757
min_n_eff = 10000
4858
rtol = 0.1
4959
atol = 0.05
60+
ks_thin = 10
5061
step_args = {"rng": np.random.default_rng(SEED)}
5162

5263

@@ -367,3 +378,45 @@ def test_discrete_steps(self, step, step_kwargs):
367378
)
368379
def test_continuous_steps(self, step, step_kwargs):
369380
self.continuous_steps(step, step_kwargs)
381+
382+
383+
@pytest.mark.parametrize(
384+
["step_method", "model_fn"],
385+
[
386+
[Metropolis, simple_model],
387+
[BinaryMetropolis, simple_binary],
388+
[BinaryGibbsMetropolis, simple_binary],
389+
[CategoricalGibbsMetropolis, simple_categorical],
390+
[DEMetropolis, simple_model],
391+
[DEMetropolisZ, simple_model],
392+
],
393+
)
394+
def test_sampling_state(step_method, model_fn):
395+
with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
396+
initial_point, model, _ = model_fn()
397+
with model:
398+
sampler = step_method(model.value_vars)
399+
if hasattr(sampler, "link_population"):
400+
sampler.link_population([initial_point] * 100, 0)
401+
sampler_orig = deepcopy(sampler)
402+
state_orig = sampler_orig.sampling_state
403+
404+
sample1, stat1 = sampler.step(initial_point)
405+
sampler.tune = False
406+
407+
final_state1 = sampler.sampling_state
408+
409+
assert not equal_sampling_states(final_state1, state_orig)
410+
411+
sampler.sampling_state = state_orig
412+
413+
assert equal_sampling_states(sampler.sampling_state, state_orig)
414+
415+
sample2, stat2 = sampler.step(initial_point)
416+
sampler.tune = False
417+
418+
final_state2 = sampler.sampling_state
419+
420+
assert equal_sampling_states(final_state1, final_state2)
421+
assert equal_dataclass_values(sample1, sample2)
422+
assert equal_dataclass_values(stat1, stat2)

0 commit comments

Comments
 (0)