 import logging
 import warnings
 
-from typing import List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import aesara
 import arviz as az
@@ -25,7 +25,8 @@
 
 import pymc as pm
 
-from pymc.blocking import DictToArrayBijection
+from pymc.aesaraf import compile_rv_inplace
+from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.model import Model, Point
 from pymc.step_methods.arraystep import ArrayStepShared, Competence, metrop_select
 from pymc.step_methods.compound import CompoundStep
@@ -66,20 +67,20 @@ def __init__(self, *args, **kwargs):
             self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above
 
         # extract some necessary variables
-        value_vars = kwargs.get("vars", None)
-        if value_vars is None:
-            value_vars = model.value_vars
+        vars = kwargs.get("vars", None)
+        if vars is None:
+            vars = model.value_vars
         else:
-            value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
-        value_vars = pm.inputvars(value_vars)
-        shared = pm.make_shared_replacements(initial_values, value_vars, model)
+            vars = [model.rvs_to_values.get(var, var) for var in vars]
+        vars = pm.inputvars(vars)
+        shared = pm.make_shared_replacements(initial_values, vars, model)
 
         # call parent class __init__
         super().__init__(*args, **kwargs)
 
         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, value_vars, shared)
+            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
         self.model = model
 
     def reset_tuning(self):
@@ -136,20 +137,20 @@ def __init__(self, *args, **kwargs):
             self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above
 
         # extract some necessary variables
-        value_vars = kwargs.get("vars", None)
-        if value_vars is None:
-            value_vars = model.value_vars
+        vars = kwargs.get("vars", None)
+        if vars is None:
+            vars = model.value_vars
         else:
-            value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
-        value_vars = pm.inputvars(value_vars)
-        shared = pm.make_shared_replacements(initial_values, value_vars, model)
+            vars = [model.rvs_to_values.get(var, var) for var in vars]
+        vars = pm.inputvars(vars)
+        shared = pm.make_shared_replacements(initial_values, vars, model)
 
         # call parent class __init__
         super().__init__(*args, **kwargs)
 
         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, value_vars, shared)
+            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
         self.model = model
 
     def reset_tuning(self):
@@ -363,7 +364,7 @@ class MLDA(ArrayStepShared):
     def __init__(
         self,
         coarse_models: List[Model],
-        value_vars: Optional[list] = None,
+        vars: Optional[list] = None,
         base_sampler="DEMetropolisZ",
         base_S: Optional = None,
         base_proposal_dist: Optional[Type[Proposal]] = None,
@@ -386,10 +387,6 @@ def __init__(
         # this variable is used to identify MLDA objects which are
         # not in the finest level (i.e. child MLDA objects)
         self.is_child = kwargs.get("is_child", False)
-        if not self.is_child:
-            warnings.warn(
-                "The MLDA implementation in PyMC is still immature. You should be particularly critical of its results."
-            )
 
         if not isinstance(coarse_models, list):
             raise ValueError("MLDA step method cannot use coarse_models if it is not a list")
@@ -546,20 +543,20 @@ def __init__(
         self.mode = mode
 
         # Process model variables
-        if value_vars is None:
-            value_vars = model.value_vars
+        if vars is None:
+            vars = model.value_vars
         else:
-            value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
-        value_vars = pm.inputvars(value_vars)
-        self.vars = value_vars
+            vars = [model.rvs_to_values.get(var, var) for var in vars]
+        vars = pm.inputvars(vars)
+        self.vars = vars
         self.var_names = [var.name for var in self.vars]
 
         self.accepted = 0
 
         # Construct Aesara function for current-level model likelihood
         # (for use in acceptance)
-        shared = pm.make_shared_replacements(initial_values, value_vars, model)
-        self.delta_logp = delta_logp_inverse(initial_values, model.logpt, value_vars, shared)
+        shared = pm.make_shared_replacements(initial_values, vars, model)
+        self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
 
         # Construct Aesara function for below-level model likelihood
         # (for use in acceptance)
@@ -571,7 +568,7 @@ def __init__(
             initial_values, model_below.logpt, vars_below, shared_below
         )
 
-        super().__init__(value_vars, shared)
+        super().__init__(vars, shared)
 
         # initialise complete step method hierarchy
         if self.num_levels == 2:
@@ -643,7 +640,7 @@ def __init__(
 
             # MLDA sampler in some intermediate level, targeting self.model_below
             self.step_method_below = pm.MLDA(
-                value_vars=vars_below,
+                vars=vars_below,
                 base_S=self.base_S,
                 base_sampler=self.base_sampler,
                 base_proposal_dist=self.base_proposal_dist,
@@ -715,7 +712,7 @@ def __init__(
         if self.store_Q_fine and not self.is_child:
             self.stats_dtypes[0][f"Q_{self.num_levels - 1}"] = object
 
-    def astep(self, q0):
+    def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         """One MLDA step, given current sample q0"""
         # Check if the tuning flag has been changed and if yes,
         # change the proposal's tuning flag and reset self.accepted
@@ -730,10 +727,6 @@ def astep(self, q0):
                     method.tune = self.tune
             self.accepted = 0
 
-        # Convert current sample from numpy array ->
-        # dict before feeding to proposal
-        q0_dict = DictToArrayBijection.rmap(q0)
-
         # Set subchain_selection (which sample from the coarse chain
         # is passed as a proposal to the fine chain). If variance
         # reduction is used, a random sample is selected as proposal.
@@ -747,14 +740,13 @@ def astep(self, q0):
 
         # Call the recursive DA proposal to get proposed sample
         # and convert dict -> numpy array
-        pre_q = self.proposal_dist(q0_dict)
-        q = DictToArrayBijection.map(pre_q)
+        q = self.proposal_dist(q0)
 
         # Evaluate MLDA acceptance log-ratio
         # If proposed sample from lower levels is the same as current one,
         # do not calculate likelihood, just set accept to 0.0
         if (q.data == q0.data).all():
-            accept = np.float(0.0)
+            accept = np.float64(0.0)
             skipped_logp = True
         else:
             accept = self.delta_logp(q.data, q0.data) + self.delta_logp_below(q0.data, q.data)
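Note (not part of the diff): with this hunk, astep() passes RaveledVars end to end, so the equality check and the compiled delta_logp calls above operate on the flat .data array rather than a point dictionary. A minimal sketch of the round trip, assuming the pymc.blocking API at this revision; the variable names and values are made up for illustration:

    import numpy as np
    from pymc.blocking import DictToArrayBijection

    # A point dictionary of named values, as produced by a PyMC model.
    point = {"mu": np.array(0.5), "sigma_log__": np.array([0.1, -0.2])}

    # map() ravels the values into one flat array plus name/shape metadata.
    q = DictToArrayBijection.map(point)
    print(q.data)  # expected: array([ 0.5,  0.1, -0.2])

    # rmap() reverses the mapping back to a dictionary of named values.
    roundtrip = DictToArrayBijection.rmap(q)
    assert all(np.array_equal(roundtrip[k], point[k]) for k in point)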
@@ -811,22 +803,22 @@ def astep(self, q0):
         if isinstance(self.step_method_below, MLDA):
             self.base_tuning_stats = self.step_method_below.base_tuning_stats
         elif isinstance(self.step_method_below, MetropolisMLDA):
-            self.base_tuning_stats.append({"base_scaling": self.step_method_below.scaling[0]})
+            self.base_tuning_stats.append({"base_scaling": self.step_method_below.scaling})
         elif isinstance(self.step_method_below, DEMetropolisZMLDA):
             self.base_tuning_stats.append(
                 {
-                    "base_scaling": self.step_method_below.scaling[0],
+                    "base_scaling": self.step_method_below.scaling,
                     "base_lambda": self.step_method_below.lamb,
                 }
             )
         elif isinstance(self.step_method_below, CompoundStep):
             # Below method is CompoundStep
             for method in self.step_method_below.methods:
                 if isinstance(method, MetropolisMLDA):
-                    self.base_tuning_stats.append({"base_scaling": method.scaling[0]})
+                    self.base_tuning_stats.append({"base_scaling": method.scaling})
                 elif isinstance(method, DEMetropolisZMLDA):
                     self.base_tuning_stats.append(
-                        {"base_scaling": method.scaling[0], "base_lambda": method.lamb}
+                        {"base_scaling": method.scaling, "base_lambda": method.lamb}
                     )
 
         return q_new, [stats] + self.base_tuning_stats
@@ -970,7 +962,7 @@ def delta_logp_inverse(point, logp, vars, shared):
 
     logp1 = pm.CallableTensor(logp0)(inarray1)
 
-    f = aesara.function([inarray1, inarray0], -logp0 + logp1)
+    f = compile_rv_inplace([inarray1, inarray0], -logp0 + logp1)
     f.trust_input = True
     return f
 
@@ -1015,9 +1007,6 @@ def subsample(
     trace=None,
     tune=0,
     model=None,
-    random_seed=None,
-    callback=None,
-    **kwargs,
 ):
     """
     A stripped down version of sample(), which is called only
@@ -1032,19 +1021,10 @@ def subsample(
     model = pm.modelcontext(model)
     chain = 0
     random_seed = np.random.randint(2 ** 30)
-
-    if start is not None:
-        pm.sampling._check_start_shape(model, start)
-    else:
-        start = {}
+    callback = None
 
     draws += tune
 
-    step = pm.sampling.assign_step_methods(model, step, step_kwargs=kwargs)
-
-    if isinstance(step, list):
-        step = CompoundStep(step)
-
     sampling = pm.sampling._iter_sample(
         draws, step, start, trace, chain, tune, model, random_seed, callback
     )
@@ -1086,9 +1066,8 @@ def __init__(
         self.subsampling_rate = subsampling_rate
         self.subchain_selection = None
         self.tuning_end_trigger = True
-        self.trace = None
 
-    def __call__(self, q0_dict: dict) -> dict:
+    def __call__(self, q0: RaveledVars) -> RaveledVars:
         """Returns proposed sample given the current sample
         in dictionary form (q0_dict)."""
 
@@ -1097,6 +1076,10 @@ def __call__(self, q0_dict: dict) -> dict:
         _log = logging.getLogger("pymc")
         _log.setLevel(logging.ERROR)
 
+        # Convert current sample from RaveledVars ->
+        # dict before feeding to subsample.
+        q0_dict = DictToArrayBijection.rmap(q0)
+
         with self.model_below:
             # Check if the tuning flag has been set to False
             # in which case tuning is stopped. The flag is set
@@ -1106,11 +1089,10 @@ def __call__(self, q0_dict: dict) -> dict:
 
             if self.tune:
                 # Subsample in tuning mode
-                self.trace = subsample(
+                trace = subsample(
                     draws=0,
                     step=self.step_method_below,
                     start=q0_dict,
-                    trace=self.trace,
                     tune=self.subsampling_rate,
                 )
             else:
@@ -1122,11 +1104,11 @@ def __call__(self, q0_dict: dict) -> dict:
                     self.step_method_below.tuning_end_trigger = True
                     self.tuning_end_trigger = False
 
-                self.trace = subsample(
+                trace = subsample(
                     draws=self.subsampling_rate,
                     step=self.step_method_below,
                     start=q0_dict,
-                    trace=self.trace,
+                    tune=0,
                 )
 
         # set logging back to normal
@@ -1135,7 +1117,13 @@ def __call__(self, q0_dict: dict) -> dict:
         # return sample with index self.subchain_selection from the generated
         # sequence of length self.subsampling_rate. The index is set within
         # MLDA's astep() function
-        new_point = self.trace.point(-self.subsampling_rate + self.subchain_selection)
-        new_point = Point(new_point, model=self.model_below, filter_model_vars=True)
+        q_dict = trace.point(self.subchain_selection)
+
+        # Make sure output dict is ordered the same way as the input dict.
+        q_dict = Point(
+            {key: q_dict[key] for key in q0_dict.keys()},
+            model=self.model_below,
+            filter_model_vars=True,
+        )
 
-        return new_point
+        return DictToArrayBijection.map(q_dict)
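Note (not part of the diff): the public interface of the step method is unchanged by these hunks; it is still built from a list of coarse models and handed to pm.sample(). A minimal usage sketch under that assumption; the toy data, model, and tuning values below are illustrative, not taken from this PR:

    import numpy as np
    import pymc as pm

    y = np.random.normal(loc=1.0, scale=1.0, size=200)

    # Cheap approximation of the target posterior (e.g. thinned data).
    with pm.Model() as coarse_model:
        mu = pm.Normal("mu", 0.0, 5.0)
        pm.Normal("obs", mu=mu, sigma=1.0, observed=y[::10])

    # Expensive fine-level model; MLDA delegates proposals to the coarse level.
    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 5.0)
        pm.Normal("obs", mu=mu, sigma=1.0, observed=y)

        step = pm.MLDA(coarse_models=[coarse_model], subsampling_rates=[5])
        idata = pm.sample(draws=500, tune=500, step=step, chains=2)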