@@ -66,22 +66,14 @@ def __init__(self, *args, **kwargs):
         self.Q_last = np.nan
         self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above
 
-        # extract some necessary variables
-        vars = kwargs.get("vars", None)
-        if vars is None:
-            vars = model.value_vars
-        else:
-            vars = [model.rvs_to_values.get(var, var) for var in vars]
-        vars = pm.inputvars(vars)
-        shared = pm.make_shared_replacements(initial_values, vars, model)
-
         # call parent class __init__
         super().__init__(*args, **kwargs)
 
         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
             self.model = model
+            self.delta_logp_factory = self.delta_logp
+            self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)
 
     def reset_tuning(self):
         """
@@ -136,22 +128,14 @@ def __init__(self, *args, **kwargs):
         self.Q_last = np.nan
         self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above
 
-        # extract some necessary variables
-        vars = kwargs.get("vars", None)
-        if vars is None:
-            vars = model.value_vars
-        else:
-            vars = [model.rvs_to_values.get(var, var) for var in vars]
-        vars = pm.inputvars(vars)
-        shared = pm.make_shared_replacements(initial_values, vars, model)
-
         # call parent class __init__
         super().__init__(*args, **kwargs)
 
         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
             self.model = model
+            self.delta_logp_factory = self.delta_logp
+            self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)
 
     def reset_tuning(self):
         """Skips resetting of tuned sampler parameters
@@ -556,7 +540,7 @@ def __init__(
         # Construct Aesara function for current-level model likelihood
         # (for use in acceptance)
         shared = pm.make_shared_replacements(initial_values, vars, model)
-        self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
+        self.delta_logp = delta_logp(initial_values, model.logpt, vars, shared)
 
         # Construct Aesara function for below-level model likelihood
         # (for use in acceptance)
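For reference, the stock delta_logp factory substituted in here is presumably built the same way as the removed delta_logp_inverse (shown in the last hunk below), differing only in writing the compiled output as logp1 - logp0 rather than -logp0 + logp1 — the same value algebraically, with the opposite evaluation order. A sketch under that assumption:

    def delta_logp(point, logp, vars, shared):
        # identical construction to the removed delta_logp_inverse,
        # except for the order of the terms in the compiled output
        [logp0], inarray0 = pm.join_nonshared_inputs(point, [logp], vars, shared)

        tensor_type = inarray0.type
        inarray1 = tensor_type("inarray1")

        logp1 = pm.CallableTensor(logp0)(inarray1)

        f = compile_rv_inplace([inarray1, inarray0], logp1 - logp0)
        f.trust_input = True
        return f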
@@ -749,7 +733,9 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
             accept = np.float64(0.0)
             skipped_logp = True
         else:
-            accept = self.delta_logp(q.data, q0.data) + self.delta_logp_below(q0.data, q.data)
+            # NB! The order and sign of the first term are swapped compared
+            # to the convention to make sure the proposal is evaluated last.
+            accept = -self.delta_logp(q0.data, q.data) + self.delta_logp_below(q0.data, q.data)
             skipped_logp = False
 
         # Accept/reject sample - next sample is stored in q_new
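The NB comment plausibly matters because of variance reduction: the quantity of interest is recorded as a side effect of evaluating the model log-density, and the sampler later reads whatever value the most recent evaluation left behind. A hypothetical side-effecting sketch (plain Python, names invented for illustration, a guess at the motivation) of why the proposal must be evaluated last:

    import numpy as np

    state = {"Q": None}  # stand-in for the model's registered quantity of interest

    def logp(x):
        state["Q"] = float(x.sum())  # Q is updated as a side effect of evaluation
        return -0.5 * float(x @ x)

    def delta_logp_proposal_last(q0, q):
        lp0 = logp(q0)  # current point first
        lp1 = logp(q)   # proposal last, so state["Q"] now reflects q
        return lp1 - lp0

    q0, q = np.array([0.0, 0.5]), np.array([0.3, -1.2])
    delta_logp_proposal_last(q0, q)
    assert state["Q"] == float(q.sum())  # Q matches the proposed point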
@@ -954,19 +940,6 @@ def update(self, x):
         self.t += 1
 
 
-def delta_logp_inverse(point, logp, vars, shared):
-    [logp0], inarray0 = pm.join_nonshared_inputs(point, [logp], vars, shared)
-
-    tensor_type = inarray0.type
-    inarray1 = tensor_type("inarray1")
-
-    logp1 = pm.CallableTensor(logp0)(inarray1)
-
-    f = compile_rv_inplace([inarray1, inarray0], -logp0 + logp1)
-    f.trust_input = True
-    return f
-
-
 def extract_Q_estimate(trace, levels):
     """
     Returns expectation and standard error of quantity of interest,
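A quick numeric check that removing delta_logp_inverse changes nothing about the accepted values: writing each level's delta as logp(a) - logp(b), the old expression delta_logp(q, q0) + delta_logp_below(q0, q) and the new -delta_logp(q0, q) + delta_logp_below(q0, q) agree exactly. A toy verification with invented densities standing in for the compiled functions:

    import numpy as np

    # hypothetical current-level and below-level log-densities, for illustration
    logp_current = lambda x: -0.5 * float(x @ x)
    logp_below = lambda x: -0.5 * float((x - 0.1) @ (x - 0.1))

    delta_logp = lambda a, b: logp_current(a) - logp_current(b)
    delta_logp_below = lambda a, b: logp_below(a) - logp_below(b)

    q, q0 = np.array([0.4]), np.array([-0.2])

    old_accept = delta_logp(q, q0) + delta_logp_below(q0, q)
    new_accept = -delta_logp(q0, q) + delta_logp_below(q0, q)
    assert np.isclose(old_accept, new_accept)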