
Commit ebf7f2c
Refactored MLDA proposal to not use trace continuation
1 parent: f26845d

File tree: 2 files changed, +174 -216 lines

pymc/step_methods/mlda.py (53 additions, 65 deletions)
@@ -15,17 +15,18 @@
 import logging
 import warnings
 
-from typing import List, Optional, Type, Union
+from typing import List, Optional, Type, Union, Dict, Tuple, Any
 
 import aesara
 import arviz as az
 import numpy as np
 
 from aesara.tensor.sharedvar import TensorSharedVariable
+from pymc.aesaraf import compile_rv_inplace
 
 import pymc as pm
 
-from pymc.blocking import DictToArrayBijection
+from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.model import Model, Point
 from pymc.step_methods.arraystep import ArrayStepShared, Competence, metrop_select
 from pymc.step_methods.compound import CompoundStep
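Note: the two new imports are what the refactor leans on. RaveledVars is the flat-vector form that DictToArrayBijection.map / .rmap convert a point dict to and from, and compile_rv_inplace replaces aesara.function for compiling the log-density graphs (see the delta_logp_inverse hunk below). The following is a minimal conceptual sketch of the map/rmap round trip in plain NumPy; ravel_point and unravel_point are illustrative stand-ins, not the pymc.blocking API:

import numpy as np

# Illustrative stand-ins for DictToArrayBijection.map / .rmap (not the pymc API).
def ravel_point(point):
    """Flatten a dict of arrays into one vector plus the info needed to invert it."""
    info = [(name, val.shape) for name, val in point.items()]
    data = np.concatenate([np.ravel(val) for val in point.values()])
    return data, info

def unravel_point(data, info):
    """Rebuild the dict of arrays from the flat vector."""
    point, offset = {}, 0
    for name, shape in info:
        size = int(np.prod(shape, dtype=int))
        point[name] = data[offset : offset + size].reshape(shape)
        offset += size
    return point

q0 = {"mu": np.array([0.1, -0.3]), "sigma_log__": np.array(0.5)}
flat, info = ravel_point(q0)
assert all(np.array_equal(q0[k], v) for k, v in unravel_point(flat, info).items())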
@@ -66,20 +67,20 @@ def __init__(self, *args, **kwargs):
         self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above
 
         # extract some necessary variables
-        value_vars = kwargs.get("vars", None)
-        if value_vars is None:
-            value_vars = model.value_vars
+        vars = kwargs.get("vars", None)
+        if vars is None:
+            vars = model.value_vars
         else:
-            value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
-        value_vars = pm.inputvars(value_vars)
-        shared = pm.make_shared_replacements(initial_values, value_vars, model)
+            vars = [model.rvs_to_values.get(var, var) for var in vars]
+        vars = pm.inputvars(vars)
+        shared = pm.make_shared_replacements(initial_values, vars, model)
 
         # call parent class __init__
         super().__init__(*args, **kwargs)
 
         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, value_vars, shared)
+            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
             self.model = model
 
     def reset_tuning(self):
@@ -136,20 +137,20 @@ def __init__(self, *args, **kwargs):
         self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above
 
         # extract some necessary variables
-        value_vars = kwargs.get("vars", None)
-        if value_vars is None:
-            value_vars = model.value_vars
+        vars = kwargs.get("vars", None)
+        if vars is None:
+            vars = model.value_vars
         else:
-            value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
-        value_vars = pm.inputvars(value_vars)
-        shared = pm.make_shared_replacements(initial_values, value_vars, model)
+            vars = [model.rvs_to_values.get(var, var) for var in vars]
+        vars = pm.inputvars(vars)
+        shared = pm.make_shared_replacements(initial_values, vars, model)
 
         # call parent class __init__
         super().__init__(*args, **kwargs)
 
         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, value_vars, shared)
+            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
             self.model = model
 
     def reset_tuning(self):
@@ -363,7 +364,7 @@ class MLDA(ArrayStepShared):
     def __init__(
        self,
        coarse_models: List[Model],
-        value_vars: Optional[list] = None,
+        vars: Optional[list] = None,
        base_sampler="DEMetropolisZ",
        base_S: Optional = None,
        base_proposal_dist: Optional[Type[Proposal]] = None,
@@ -386,10 +387,6 @@ def __init__(
         # this variable is used to identify MLDA objects which are
         # not in the finest level (i.e. child MLDA objects)
         self.is_child = kwargs.get("is_child", False)
-        if not self.is_child:
-            warnings.warn(
-                "The MLDA implementation in PyMC is still immature. You should be particularly critical of its results."
-            )
 
         if not isinstance(coarse_models, list):
             raise ValueError("MLDA step method cannot use coarse_models if it is not a list")
@@ -546,20 +543,20 @@ def __init__(
         self.mode = mode
 
         # Process model variables
-        if value_vars is None:
-            value_vars = model.value_vars
+        if vars is None:
+            vars = model.value_vars
         else:
-            value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
-        value_vars = pm.inputvars(value_vars)
-        self.vars = value_vars
+            vars = [model.rvs_to_values.get(var, var) for var in vars]
+        vars = pm.inputvars(vars)
+        self.vars = vars
         self.var_names = [var.name for var in self.vars]
 
         self.accepted = 0
 
         # Construct Aesara function for current-level model likelihood
         # (for use in acceptance)
-        shared = pm.make_shared_replacements(initial_values, value_vars, model)
-        self.delta_logp = delta_logp_inverse(initial_values, model.logpt, value_vars, shared)
+        shared = pm.make_shared_replacements(initial_values, vars, model)
+        self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
 
         # Construct Aesara function for below-level model likelihood
         # (for use in acceptance)
@@ -571,7 +568,7 @@ def __init__(
             initial_values, model_below.logpt, vars_below, shared_below
         )
 
-        super().__init__(value_vars, shared)
+        super().__init__(vars, shared)
 
         # initialise complete step method hierarchy
         if self.num_levels == 2:
@@ -643,7 +640,7 @@ def __init__(
 
             # MLDA sampler in some intermediate level, targeting self.model_below
             self.step_method_below = pm.MLDA(
-                value_vars=vars_below,
+                vars=vars_below,
                 base_S=self.base_S,
                 base_sampler=self.base_sampler,
                 base_proposal_dist=self.base_proposal_dist,
@@ -715,7 +712,7 @@ def __init__(
         if self.store_Q_fine and not self.is_child:
             self.stats_dtypes[0][f"Q_{self.num_levels - 1}"] = object
 
-    def astep(self, q0):
+    def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         """One MLDA step, given current sample q0"""
         # Check if the tuning flag has been changed and if yes,
         # change the proposal's tuning flag and reset self.accepted
@@ -730,10 +727,6 @@ def astep(self, q0):
                     method.tune = self.tune
             self.accepted = 0
 
-        # Convert current sample from numpy array ->
-        # dict before feeding to proposal
-        q0_dict = DictToArrayBijection.rmap(q0)
-
         # Set subchain_selection (which sample from the coarse chain
         # is passed as a proposal to the fine chain). If variance
         # reduction is used, a random sample is selected as proposal.
@@ -747,14 +740,13 @@
 
         # Call the recursive DA proposal to get proposed sample
         # and convert dict -> numpy array
-        pre_q = self.proposal_dist(q0_dict)
-        q = DictToArrayBijection.map(pre_q)
+        q = self.proposal_dist(q0)
 
         # Evaluate MLDA acceptance log-ratio
         # If proposed sample from lower levels is the same as current one,
         # do not calculate likelihood, just set accept to 0.0
         if (q.data == q0.data).all():
-            accept = np.float(0.0)
+            accept = np.float64(0.0)
             skipped_logp = True
         else:
             accept = self.delta_logp(q.data, q0.data) + self.delta_logp_below(q0.data, q.data)
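Note: the two delta terms form the usual two-level delayed-acceptance log-ratio, log p_fine(q) - log p_fine(q0) + log p_coarse(q0) - log p_coarse(q). A small standalone sketch of that accept/reject step with toy Gaussian densities (plain NumPy, not the compiled Aesara functions used here):

import numpy as np

rng = np.random.default_rng(0)

def logp_fine(x):    # toy fine-level log-density (standard normal)
    return -0.5 * np.sum(x ** 2)

def logp_coarse(x):  # toy coarse-level log-density (deliberately mismatched variance)
    return -0.5 * np.sum(x ** 2) / 1.2

q0 = np.zeros(3)
q = q0 + rng.normal(scale=0.5, size=3)  # stand-in for the subchain proposal

# Two-level delayed-acceptance log-ratio, mirroring
# delta_logp(q, q0) + delta_logp_below(q0, q) above.
accept = (logp_fine(q) - logp_fine(q0)) + (logp_coarse(q0) - logp_coarse(q))
q_new = q if np.log(rng.uniform()) < accept else q0
print(accept, q_new)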
@@ -811,22 +803,22 @@ def astep(self, q0):
         if isinstance(self.step_method_below, MLDA):
             self.base_tuning_stats = self.step_method_below.base_tuning_stats
         elif isinstance(self.step_method_below, MetropolisMLDA):
-            self.base_tuning_stats.append({"base_scaling": self.step_method_below.scaling[0]})
+            self.base_tuning_stats.append({"base_scaling": self.step_method_below.scaling})
         elif isinstance(self.step_method_below, DEMetropolisZMLDA):
             self.base_tuning_stats.append(
                 {
-                    "base_scaling": self.step_method_below.scaling[0],
+                    "base_scaling": self.step_method_below.scaling,
                     "base_lambda": self.step_method_below.lamb,
                 }
             )
         elif isinstance(self.step_method_below, CompoundStep):
             # Below method is CompoundStep
             for method in self.step_method_below.methods:
                 if isinstance(method, MetropolisMLDA):
-                    self.base_tuning_stats.append({"base_scaling": method.scaling[0]})
+                    self.base_tuning_stats.append({"base_scaling": method.scaling})
                 elif isinstance(method, DEMetropolisZMLDA):
                     self.base_tuning_stats.append(
-                        {"base_scaling": method.scaling[0], "base_lambda": method.lamb}
+                        {"base_scaling": method.scaling, "base_lambda": method.lamb}
                     )
 
         return q_new, [stats] + self.base_tuning_stats
@@ -970,7 +962,7 @@ def delta_logp_inverse(point, logp, vars, shared):
 
     logp1 = pm.CallableTensor(logp0)(inarray1)
 
-    f = aesara.function([inarray1, inarray0], -logp0 + logp1)
+    f = compile_rv_inplace([inarray1, inarray0], -logp0 + logp1)
     f.trust_input = True
     return f
 
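Note: compile_rv_inplace is used here as a drop-in for aesara.function when compiling the log-density difference. A standalone sketch of the same pattern with a toy log-density, written against aesara.function directly since the toy graph has no random variables to handle; toy_logp is an illustrative stand-in for the model's joint log-density:

import aesara
import aesara.tensor as at
import numpy as np

inarray0 = at.dvector("inarray0")  # current point (flattened)
inarray1 = at.dvector("inarray1")  # proposed point (flattened)

def toy_logp(x):
    # stand-in for model.logpt
    return -0.5 * at.sum(x ** 2)

# Same shape as delta_logp_inverse: f(q, q0) = logp(q) - logp(q0)
f = aesara.function([inarray1, inarray0], -toy_logp(inarray0) + toy_logp(inarray1))
f.trust_input = True

print(f(np.zeros(2), np.ones(2)))  # 1.0: the proposed point is more probable under the toy density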
@@ -1015,9 +1007,6 @@ def subsample(
    trace=None,
    tune=0,
    model=None,
-    random_seed=None,
-    callback=None,
-    **kwargs,
 ):
    """
    A stripped down version of sample(), which is called only
@@ -1032,19 +1021,10 @@
     model = pm.modelcontext(model)
     chain = 0
     random_seed = np.random.randint(2 ** 30)
-
-    if start is not None:
-        pm.sampling._check_start_shape(model, start)
-    else:
-        start = {}
+    callback = None
 
     draws += tune
 
-    step = pm.sampling.assign_step_methods(model, step, step_kwargs=kwargs)
-
-    if isinstance(step, list):
-        step = CompoundStep(step)
-
     sampling = pm.sampling._iter_sample(
         draws, step, start, trace, chain, tune, model, random_seed, callback
     )
@@ -1086,9 +1066,8 @@ def __init__(
         self.subsampling_rate = subsampling_rate
         self.subchain_selection = None
         self.tuning_end_trigger = True
-        self.trace = None
 
-    def __call__(self, q0_dict: dict) -> dict:
+    def __call__(self, q0: RaveledVars) -> RaveledVars:
         """Returns proposed sample given the current sample
         in dictionary form (q0_dict)."""
 
@@ -1097,6 +1076,10 @@ def __call__(self, q0_dict: dict) -> dict:
         _log = logging.getLogger("pymc")
         _log.setLevel(logging.ERROR)
 
+        # Convert current sample from RaveledVars ->
+        # dict before feeding to subsample.
+        q0_dict = DictToArrayBijection.rmap(q0)
+
         with self.model_below:
             # Check if the tuning flag has been set to False
             # in which case tuning is stopped. The flag is set
@@ -1106,11 +1089,10 @@
 
             if self.tune:
                 # Subsample in tuning mode
-                self.trace = subsample(
+                trace = subsample(
                     draws=0,
                     step=self.step_method_below,
                     start=q0_dict,
-                    trace=self.trace,
                     tune=self.subsampling_rate,
                 )
             else:
@@ -1122,11 +1104,11 @@
                     self.step_method_below.tuning_end_trigger = True
                     self.tuning_end_trigger = False
 
-                self.trace = subsample(
+                trace = subsample(
                     draws=self.subsampling_rate,
                     step=self.step_method_below,
                     start=q0_dict,
-                    trace=self.trace,
+                    tune=0,
                 )
 
         # set logging back to normal
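Note: dropping trace continuation changes the bookkeeping, not the selected sample. Previously one trace was extended across calls and indexed from its tail; now each call builds a fresh trace and indexes it from the start (next hunk). A tiny plain-Python sketch of that equivalence, with lists standing in for traces:

# Plain-Python illustration; lists stand in for MultiTrace objects.
subsampling_rate = 5
subchain_selection = 2  # chosen in MLDA.astep()

# Old behaviour: one continued trace, index relative to its tail.
continued_trace = list(range(20))                  # every sample drawn so far
old_pick = continued_trace[-subsampling_rate + subchain_selection]

# New behaviour: a fresh trace per proposal call, index from its start.
fresh_trace = continued_trace[-subsampling_rate:]  # only this call's samples
new_pick = fresh_trace[subchain_selection]

assert old_pick == new_pick == 17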
@@ -1135,7 +1117,13 @@
         # return sample with index self.subchain_selection from the generated
         # sequence of length self.subsampling_rate. The index is set within
         # MLDA's astep() function
-        new_point = self.trace.point(-self.subsampling_rate + self.subchain_selection)
-        new_point = Point(new_point, model=self.model_below, filter_model_vars=True)
+        q_dict = trace.point(self.subchain_selection)
+
+        # Make sure output dict is ordered the same way as the input dict.
+        q_dict = Point(
+            {key: q_dict[key] for key in q0_dict.keys()},
+            model=self.model_below,
+            filter_model_vars=True,
+        )
 
-        return new_point
+        return DictToArrayBijection.map(q_dict)
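Note: the re-keying step matters because DictToArrayBijection.map concatenates values in dict insertion order, so the returned point must use q0's key order for the flat vectors to line up (e.g. in the q.data == q0.data check in astep). A conceptual NumPy sketch; ravel_point is an illustrative stand-in, not the pymc API:

import numpy as np

def ravel_point(point):
    # stand-in for DictToArrayBijection.map: concatenation follows insertion order
    return np.concatenate([np.ravel(v) for v in point.values()])

q0_dict = {"a": np.array([1.0, 2.0]), "b": np.array([3.0])}
q_dict = {"b": np.array([3.0]), "a": np.array([1.0, 2.0])}  # same values, keys swapped

reordered = {key: q_dict[key] for key in q0_dict.keys()}

print(ravel_point(q_dict))     # [3. 1. 2.] -> would not line up with q0
print(ravel_point(reordered))  # [1. 2. 3.] -> matches q0's layout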
