@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections.abc import Callable
+from dataclasses import field
+from typing import Any

 import numpy as np
 import numpy.random as nr
@@ -40,7 +42,8 @@
     StatsType,
     metrop_select,
 )
-from pymc.step_methods.compound import Competence
+from pymc.step_methods.compound import Competence, StepMethodState
+from pymc.step_methods.state import dataclass_state

 __all__ = [
     "Metropolis",
@@ -111,18 +114,40 @@ def __call__(self, num_draws=None, rng: np.random.Generator | None = None):
         return np.dot(self.chol, b)


+@dataclass_state
+class MetropolisState(StepMethodState):
+    scaling: np.ndarray
+    tune: bool
+    steps_until_tune: float
+    tune_interval: float
+    accepted_sum: np.ndarray
+    accept_rate_iter: np.ndarray
+    accepted_iter: np.ndarray
+    enum_dims: np.ndarray
+
+    discrete: np.ndarray = field(metadata={"frozen": True})
+    any_discrete: bool = field(metadata={"frozen": True})
+    all_discrete: bool = field(metadata={"frozen": True})
+    elemwise_update: bool = field(metadata={"frozen": True})
+    _untuned_settings: dict[str, np.ndarray | float] = field(metadata={"frozen": True})
+    mode: Any = field(metadata={"frozen": True})
+
+
 class Metropolis(ArrayStepShared):
     """Metropolis-Hastings sampling step"""

     name = "metropolis"

+    default_blocked = False
     stats_dtypes_shapes = {
         "accept": (np.float64, []),
         "accepted": (np.float64, []),
         "tune": (bool, []),
         "scaling": (np.float64, []),
     }

+    _state_class = MetropolisState
+
     def __init__(
         self,
         vars=None,
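Note on the pattern above: each `*State` dataclass mirrors the mutable attributes of its step method, and fields tagged `field(metadata={"frozen": True})` mark fixed configuration that a restored state is not allowed to change. Below is a self-contained sketch of that snapshot/restore idea using plain dataclasses; `ToySampler`, `snapshot`, and `restore` are hypothetical names for illustration, not the PyMC API (the real machinery lives in `pymc.step_methods.state`).

```python
# Illustration only: a plain-dataclass version of the state pattern in this
# diff. ToySampler, snapshot, and restore are made-up names.
from dataclasses import dataclass, field, fields

import numpy as np


@dataclass
class ToySamplerState:
    scaling: np.ndarray  # evolves during tuning
    steps_until_tune: int  # evolves during sampling
    tune_interval: int = field(metadata={"frozen": True})  # fixed setting


class ToySampler:
    def __init__(self, tune_interval: int = 100):
        self.scaling = np.array(1.0)
        self.steps_until_tune = tune_interval
        self.tune_interval = tune_interval

    def snapshot(self) -> ToySamplerState:
        # Copy every declared field from the sampler into the state object.
        kwargs = {f.name: getattr(self, f.name) for f in fields(ToySamplerState)}
        return ToySamplerState(**kwargs)

    def restore(self, state: ToySamplerState) -> None:
        for f in fields(ToySamplerState):
            new = getattr(state, f.name)
            if f.metadata.get("frozen") and getattr(self, f.name) != new:
                # Frozen fields are configuration: they must already match.
                raise ValueError(f"frozen field {f.name!r} differs")
            setattr(self, f.name, new)


sampler = ToySampler()
saved = sampler.snapshot()
sampler.steps_until_tune = 7  # ... sampling mutates the sampler ...
sampler.restore(saved)
assert sampler.steps_until_tune == 100
```

With this pattern, restoring a snapshot onto a sampler with identical settings rewinds only the mutable quantities, while a mismatch in a frozen setting fails loudly.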
@@ -346,6 +371,15 @@ def tune(scale, acc_rate):
     )


+@dataclass_state
+class BinaryMetropolisState(StepMethodState):
+    tune: bool
+    accepted: int
+    scaling: float
+    tune_interval: int
+    steps_until_tune: int
+
+
 class BinaryMetropolis(ArrayStep):
     """Metropolis-Hastings optimized for binary variables

@@ -375,7 +409,9 @@ class BinaryMetropolis(ArrayStep):
         "p_jump": (np.float64, []),
     }

-    def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
+    _state_class = BinaryMetropolisState
+
+    def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None, rng=None):
         model = pm.modelcontext(model)

         self.scaling = scaling
@@ -389,7 +425,7 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
         if not all([v.dtype in pm.discrete_types for v in vars]):
             raise ValueError("All variables must be Bernoulli for BinaryMetropolis")

-        super().__init__(vars, [model.compile_logp()])
+        super().__init__(vars, [model.compile_logp()], rng=rng)

     def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         logp = args[0]
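The `rng` keyword threaded through `__init__` into the base class gives each step method its own random generator. A sketch of the intended use, assuming this PR is applied; only the keyword itself is visible in the hunk, so passing a `numpy.random.Generator` (rather than, say, an integer seed) is an assumption here:

```python
# Sketch (assumes this PR is applied): seeding BinaryMetropolis directly.
import numpy as np
import pymc as pm

with pm.Model():
    coin = pm.Bernoulli("coin", p=0.3, shape=4)
    # rng is the new keyword added by this diff; a Generator is assumed.
    step = pm.BinaryMetropolis([coin], rng=np.random.default_rng(42))
    idata = pm.sample(draws=200, tune=100, step=step, chains=1)
```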
@@ -445,6 +481,14 @@ def competence(var):
         return Competence.INCOMPATIBLE


+@dataclass_state
+class BinaryGibbsMetropolisState(StepMethodState):
+    tune: bool
+    transit_p: int
+    shuffle_dims: bool
+    order: list
+
+
 class BinaryGibbsMetropolis(ArrayStep):
     """A Metropolis-within-Gibbs step method optimized for binary variables

@@ -472,7 +516,9 @@ class BinaryGibbsMetropolis(ArrayStep):
         "tune": (bool, []),
     }

-    def __init__(self, vars, order="random", transit_p=0.8, model=None):
+    _state_class = BinaryGibbsMetropolisState
+
+    def __init__(self, vars, order="random", transit_p=0.8, model=None, rng=None):
         model = pm.modelcontext(model)

         # Doesn't actually tune, but it's required to emit a sampler stat
@@ -498,7 +544,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
         if not all([v.dtype in pm.discrete_types for v in vars]):
             raise ValueError("All variables must be binary for BinaryGibbsMetropolis")

-        super().__init__(vars, [model.compile_logp()])
+        super().__init__(vars, [model.compile_logp()], rng=rng)

     def reset_tuning(self):
         # There are no tuning parameters in this step method.
@@ -557,6 +603,13 @@ def competence(var):
         return Competence.INCOMPATIBLE


+@dataclass_state
+class CategoricalGibbsMetropolisState(StepMethodState):
+    shuffle_dims: bool
+    dimcats: list[tuple]
+    tune: bool
+
+
 class CategoricalGibbsMetropolis(ArrayStep):
     """A Metropolis-within-Gibbs step method optimized for categorical variables.

@@ -573,6 +626,8 @@ class CategoricalGibbsMetropolis(ArrayStep):
         "tune": (bool, []),
     }

+    _state_class = CategoricalGibbsMetropolisState
+
     def __init__(self, vars, proposal="uniform", order="random", model=None, rng=None):
         model = pm.modelcontext(model)

@@ -728,6 +783,18 @@ def competence(var):
         return Competence.INCOMPATIBLE


+@dataclass_state
+class DEMetropolisState(StepMethodState):
+    scaling: np.ndarray
+    lamb: float
+    tune: str | None
+    tune_interval: int
+    steps_until_tune: int
+    accepted: int
+
+    mode: Any = field(metadata={"frozen": True})
+
+
 class DEMetropolis(PopulationArrayStepShared):
     """
     Differential Evolution Metropolis sampling step.
@@ -778,6 +845,8 @@ class DEMetropolis(PopulationArrayStepShared):
         "lambda": (np.float64, []),
     }

+    _state_class = DEMetropolisState
+
     def __init__(
         self,
         vars=None,
@@ -789,6 +858,7 @@ def __init__(
         tune_interval=100,
         model=None,
         mode=None,
+        rng=None,
         **kwargs,
     ):
         model = pm.modelcontext(model)
@@ -824,7 +894,7 @@ def __init__(

         shared = pm.make_shared_replacements(initial_values, vars, model)
         self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
-        super().__init__(vars, shared)
+        super().__init__(vars, shared, rng=rng)

     def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         point_map_info = q0.point_map_info
@@ -843,9 +913,11 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:

         # differential evolution proposal
         # select two other chains
-        ir1, ir2 = np.random.choice(self.other_chains, 2, replace=False)
-        r1 = DictToArrayBijection.map(self.population[ir1])
-        r2 = DictToArrayBijection.map(self.population[ir2])
+        if self.other_chains is None:  # pragma: no cover
+            raise RuntimeError("Population sampler has not been linked to the other chains")
+        ir1, ir2 = self.rng.choice(self.other_chains, 2, replace=False)
+        r1 = DictToArrayBijection.map(self.population[ir1])  # type: ignore
+        r2 = DictToArrayBijection.map(self.population[ir2])  # type: ignore
         # propose a jump
         q = floatX(q0d + self.lamb * (r1.data - r2.data) + epsilon)
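The switch from the module-level `np.random.choice` to `self.rng.choice` is what makes the seeding meaningful: chain-partner selection now draws from the sampler's own `Generator` instead of shared global state. A minimal demonstration of why per-instance generators are replayable:

```python
import numpy as np

# Two generators built from the same seed produce identical streams,
# regardless of what else the process does with the global np.random state.
rng_a = np.random.default_rng(123)
rng_b = np.random.default_rng(123)
np.random.shuffle(np.arange(10))  # perturb the global state; no effect below
pick_a = rng_a.choice(10, size=2, replace=False)
pick_b = rng_b.choice(10, size=2, replace=False)
assert (pick_a == pick_b).all()
```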
851
923
@@ -872,6 +944,21 @@ def competence(var, has_grad):
872
944
return Competence .COMPATIBLE
873
945
874
946
947
+ @dataclass_state
948
+ class DEMetropolisZState (StepMethodState ):
949
+ scaling : np .ndarray
950
+ lamb : float
951
+ tune : bool
952
+ tune_target : str | None
953
+ tune_interval : int
954
+ steps_until_tune : int
955
+ accepted : int
956
+ _history : list
957
+
958
+ _untuned_settings : dict [str , np .ndarray | float ] = field (metadata = {"frozen" : True })
959
+ mode : Any = field (metadata = {"frozen" : True })
960
+
961
+
875
962
class DEMetropolisZ (ArrayStepShared ):
876
963
"""
877
964
Adaptive Differential Evolution Metropolis sampling step that uses the past to inform jumps.
@@ -925,6 +1012,8 @@ class DEMetropolisZ(ArrayStepShared):
         "lambda": (np.float64, []),
     }

+    _state_class = DEMetropolisZState
+
     def __init__(
         self,
         vars=None,
@@ -937,6 +1026,7 @@ def __init__(
         tune_drop_fraction: float = 0.9,
         model=None,
         mode=None,
+        rng=None,
         **kwargs,
     ):
         model = pm.modelcontext(model)
@@ -984,7 +1074,7 @@ def __init__(

         shared = pm.make_shared_replacements(initial_values, vars, model)
         self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
-        super().__init__(vars, shared)
+        super().__init__(vars, shared, rng=rng)

     def reset_tuning(self):
         """Resets the tuned sampler parameters and history to their initial values."""