Commit 3ab1c00

Port MLDA to v4 - Refactored to not use trace continuation (#5095)
* Refactored MLDA proposal to not use trace continuation
* Ran isort on mlda.py to sort imports
* Removed xfail markers from MLDA tests in test_types.py
* Reduced tolerance in test_aem_mu_sigma to accommodate single precision floats
1 parent f26845d commit 3ab1c00
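
For orientation, a minimal usage sketch of the step method this commit ports (not part of the diff; the model, data, and tuning values below are illustrative):

import numpy as np
import pymc as pm

# Cheap, approximate model used as the lower MLDA level (illustrative).
with pm.Model() as coarse_model:
    x = pm.Normal("x", mu=0.0, sigma=10.0)
    pm.Normal("y", mu=x, sigma=1.5, observed=np.array([0.5, 1.0]))

# Expensive, accurate model that MLDA actually targets.
with pm.Model():
    x = pm.Normal("x", mu=0.0, sigma=10.0)
    pm.Normal("y", mu=x, sigma=1.0, observed=np.array([0.5, 1.0]))

    # MLDA proposes fine-level moves by running a subchain on the coarse model.
    step = pm.MLDA(coarse_models=[coarse_model], subsampling_rates=5)
    idata = pm.sample(draws=500, tune=200, chains=2, step=step)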

File tree

3 files changed: +174 -218 lines changed

pymc/step_methods/mlda.py

Lines changed: 53 additions & 65 deletions
@@ -15,7 +15,7 @@
 import logging
 import warnings

-from typing import List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union

 import aesara
 import arviz as az
@@ -25,7 +25,8 @@

 import pymc as pm

-from pymc.blocking import DictToArrayBijection
+from pymc.aesaraf import compile_rv_inplace
+from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.model import Model, Point
 from pymc.step_methods.arraystep import ArrayStepShared, Competence, metrop_select
 from pymc.step_methods.compound import CompoundStep
@@ -66,20 +67,20 @@ def __init__(self, *args, **kwargs):
         self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

         # extract some necessary variables
-        value_vars = kwargs.get("vars", None)
-        if value_vars is None:
-            value_vars = model.value_vars
+        vars = kwargs.get("vars", None)
+        if vars is None:
+            vars = model.value_vars
         else:
-            value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
-        value_vars = pm.inputvars(value_vars)
-        shared = pm.make_shared_replacements(initial_values, value_vars, model)
+            vars = [model.rvs_to_values.get(var, var) for var in vars]
+        vars = pm.inputvars(vars)
+        shared = pm.make_shared_replacements(initial_values, vars, model)

         # call parent class __init__
         super().__init__(*args, **kwargs)

         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, value_vars, shared)
+            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
         self.model = model

     def reset_tuning(self):
@@ -136,20 +137,20 @@ def __init__(self, *args, **kwargs):
         self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

         # extract some necessary variables
-        value_vars = kwargs.get("vars", None)
-        if value_vars is None:
-            value_vars = model.value_vars
+        vars = kwargs.get("vars", None)
+        if vars is None:
+            vars = model.value_vars
         else:
-            value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
-        value_vars = pm.inputvars(value_vars)
-        shared = pm.make_shared_replacements(initial_values, value_vars, model)
+            vars = [model.rvs_to_values.get(var, var) for var in vars]
+        vars = pm.inputvars(vars)
+        shared = pm.make_shared_replacements(initial_values, vars, model)

         # call parent class __init__
         super().__init__(*args, **kwargs)

         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, value_vars, shared)
+            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
         self.model = model

     def reset_tuning(self):
@@ -363,7 +364,7 @@ class MLDA(ArrayStepShared):
     def __init__(
         self,
         coarse_models: List[Model],
-        value_vars: Optional[list] = None,
+        vars: Optional[list] = None,
         base_sampler="DEMetropolisZ",
         base_S: Optional = None,
         base_proposal_dist: Optional[Type[Proposal]] = None,
@@ -386,10 +387,6 @@ def __init__(
         # this variable is used to identify MLDA objects which are
         # not in the finest level (i.e. child MLDA objects)
         self.is_child = kwargs.get("is_child", False)
-        if not self.is_child:
-            warnings.warn(
-                "The MLDA implementation in PyMC is still immature. You should be particularly critical of its results."
-            )

         if not isinstance(coarse_models, list):
             raise ValueError("MLDA step method cannot use coarse_models if it is not a list")
@@ -546,20 +543,20 @@ def __init__(
         self.mode = mode

         # Process model variables
-        if value_vars is None:
-            value_vars = model.value_vars
+        if vars is None:
+            vars = model.value_vars
         else:
-            value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
-        value_vars = pm.inputvars(value_vars)
-        self.vars = value_vars
+            vars = [model.rvs_to_values.get(var, var) for var in vars]
+        vars = pm.inputvars(vars)
+        self.vars = vars
         self.var_names = [var.name for var in self.vars]

         self.accepted = 0

         # Construct Aesara function for current-level model likelihood
         # (for use in acceptance)
-        shared = pm.make_shared_replacements(initial_values, value_vars, model)
-        self.delta_logp = delta_logp_inverse(initial_values, model.logpt, value_vars, shared)
+        shared = pm.make_shared_replacements(initial_values, vars, model)
+        self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)

         # Construct Aesara function for below-level model likelihood
         # (for use in acceptance)
@@ -571,7 +568,7 @@ def __init__(
             initial_values, model_below.logpt, vars_below, shared_below
         )

-        super().__init__(value_vars, shared)
+        super().__init__(vars, shared)

         # initialise complete step method hierarchy
         if self.num_levels == 2:
@@ -643,7 +640,7 @@ def __init__(

                 # MLDA sampler in some intermediate level, targeting self.model_below
                 self.step_method_below = pm.MLDA(
-                    value_vars=vars_below,
+                    vars=vars_below,
                     base_S=self.base_S,
                     base_sampler=self.base_sampler,
                     base_proposal_dist=self.base_proposal_dist,
@@ -715,7 +712,7 @@ def __init__(
         if self.store_Q_fine and not self.is_child:
             self.stats_dtypes[0][f"Q_{self.num_levels - 1}"] = object

-    def astep(self, q0):
+    def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         """One MLDA step, given current sample q0"""
         # Check if the tuning flag has been changed and if yes,
         # change the proposal's tuning flag and reset self.accepted
@@ -730,10 +727,6 @@ def astep(self, q0):
                     method.tune = self.tune
             self.accepted = 0

-        # Convert current sample from numpy array ->
-        # dict before feeding to proposal
-        q0_dict = DictToArrayBijection.rmap(q0)
-
         # Set subchain_selection (which sample from the coarse chain
         # is passed as a proposal to the fine chain). If variance
         # reduction is used, a random sample is selected as proposal.
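
astep now receives and returns the current sample as a RaveledVars object instead of converting to a dict up front. For reference, a rough sketch of the map/rmap round-trip in pymc.blocking that the refactored proposal relies on (variable names and values are illustrative; the exact metadata layout is an assumption):

import numpy as np
from pymc.blocking import DictToArrayBijection

# A point in dict form, one entry per (transformed) model variable.
point = {"mu": np.array(0.1), "sigma_log__": np.array([-0.5, -0.2])}

# map: dict -> RaveledVars; q.data is a single flat array plus unravelling metadata.
q = DictToArrayBijection.map(point)
print(q.data)

# rmap: RaveledVars -> dict, restoring the per-variable arrays.
roundtrip = DictToArrayBijection.rmap(q)
assert set(roundtrip) == set(point)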
@@ -747,14 +740,13 @@

         # Call the recursive DA proposal to get proposed sample
         # and convert dict -> numpy array
-        pre_q = self.proposal_dist(q0_dict)
-        q = DictToArrayBijection.map(pre_q)
+        q = self.proposal_dist(q0)

         # Evaluate MLDA acceptance log-ratio
         # If proposed sample from lower levels is the same as current one,
         # do not calculate likelihood, just set accept to 0.0
         if (q.data == q0.data).all():
-            accept = np.float(0.0)
+            accept = np.float64(0.0)
             skipped_logp = True
         else:
             accept = self.delta_logp(q.data, q0.data) + self.delta_logp_below(q0.data, q.data)
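
The two delta_logp terms above assemble the delayed-acceptance log-ratio: the fine-level density difference for the proposed move plus the below-level difference in the opposite direction. A plain-numpy sketch of that quantity (logp_fine and logp_below are hypothetical stand-ins for the compiled functions):

import numpy as np

def mlda_accept_logratio(logp_fine, logp_below, q, q0):
    # delta_logp(q, q0) + delta_logp_below(q0, q), as assembled in MLDA.astep()
    return (logp_fine(q) - logp_fine(q0)) + (logp_below(q0) - logp_below(q))

# Toy log-densities standing in for the compiled Aesara functions.
logp_fine = lambda x: -0.5 * float(np.sum(x ** 2))
logp_below = lambda x: -0.5 * float(np.sum((x / 1.1) ** 2))
print(mlda_accept_logratio(logp_fine, logp_below, np.array([0.2]), np.array([0.0])))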
@@ -811,22 +803,22 @@ def astep(self, q0):
         if isinstance(self.step_method_below, MLDA):
             self.base_tuning_stats = self.step_method_below.base_tuning_stats
         elif isinstance(self.step_method_below, MetropolisMLDA):
-            self.base_tuning_stats.append({"base_scaling": self.step_method_below.scaling[0]})
+            self.base_tuning_stats.append({"base_scaling": self.step_method_below.scaling})
         elif isinstance(self.step_method_below, DEMetropolisZMLDA):
             self.base_tuning_stats.append(
                 {
-                    "base_scaling": self.step_method_below.scaling[0],
+                    "base_scaling": self.step_method_below.scaling,
                     "base_lambda": self.step_method_below.lamb,
                 }
             )
         elif isinstance(self.step_method_below, CompoundStep):
             # Below method is CompoundStep
             for method in self.step_method_below.methods:
                 if isinstance(method, MetropolisMLDA):
-                    self.base_tuning_stats.append({"base_scaling": method.scaling[0]})
+                    self.base_tuning_stats.append({"base_scaling": method.scaling})
                 elif isinstance(method, DEMetropolisZMLDA):
                     self.base_tuning_stats.append(
-                        {"base_scaling": method.scaling[0], "base_lambda": method.lamb}
+                        {"base_scaling": method.scaling, "base_lambda": method.lamb}
                     )

         return q_new, [stats] + self.base_tuning_stats
@@ -970,7 +962,7 @@ def delta_logp_inverse(point, logp, vars, shared):

    logp1 = pm.CallableTensor(logp0)(inarray1)

-    f = aesara.function([inarray1, inarray0], -logp0 + logp1)
+    f = compile_rv_inplace([inarray1, inarray0], -logp0 + logp1)
     f.trust_input = True
     return f

@@ -1015,9 +1007,6 @@ def subsample(
     trace=None,
     tune=0,
     model=None,
-    random_seed=None,
-    callback=None,
-    **kwargs,
 ):
     """
     A stripped down version of sample(), which is called only
@@ -1032,19 +1021,10 @@ def subsample(
     model = pm.modelcontext(model)
     chain = 0
     random_seed = np.random.randint(2 ** 30)
-
-    if start is not None:
-        pm.sampling._check_start_shape(model, start)
-    else:
-        start = {}
+    callback = None

     draws += tune

-    step = pm.sampling.assign_step_methods(model, step, step_kwargs=kwargs)
-
-    if isinstance(step, list):
-        step = CompoundStep(step)
-
     sampling = pm.sampling._iter_sample(
         draws, step, start, trace, chain, tune, model, random_seed, callback
     )
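
Conceptually, subsample() just drives the lower-level step method for a fixed number of transitions each time the proposal is called, instead of continuing one long trace across calls. A plain-Python sketch of that idea (step and the chain list are stand-ins, not the PyMC internals):

def run_subchain(step, q0, subsampling_rate, subchain_selection):
    # Run the lower-level sampler for `subsampling_rate` transitions starting
    # from q0 and return the sample picked as the proposal for the level above.
    chain = []
    q = q0
    for _ in range(subsampling_rate):
        q = step(q)  # one lower-level MCMC transition (stand-in for astep)
        chain.append(q)
    return chain[subchain_selection]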
@@ -1086,9 +1066,8 @@ def __init__(
         self.subsampling_rate = subsampling_rate
         self.subchain_selection = None
         self.tuning_end_trigger = True
-        self.trace = None

-    def __call__(self, q0_dict: dict) -> dict:
+    def __call__(self, q0: RaveledVars) -> RaveledVars:
         """Returns proposed sample given the current sample
         in dictionary form (q0_dict)."""

@@ -1097,6 +1076,10 @@ def __call__(self, q0_dict: dict) -> dict:
         _log = logging.getLogger("pymc")
         _log.setLevel(logging.ERROR)

+        # Convert current sample from RaveledVars ->
+        # dict before feeding to subsample.
+        q0_dict = DictToArrayBijection.rmap(q0)
+
         with self.model_below:
             # Check if the tuning flag has been set to False
             # in which case tuning is stopped. The flag is set
@@ -1106,11 +1089,10 @@ def __call__(self, q0_dict: dict) -> dict:

             if self.tune:
                 # Subsample in tuning mode
-                self.trace = subsample(
+                trace = subsample(
                     draws=0,
                     step=self.step_method_below,
                     start=q0_dict,
-                    trace=self.trace,
                     tune=self.subsampling_rate,
                 )
             else:
@@ -1122,11 +1104,11 @@ def __call__(self, q0_dict: dict) -> dict:
                         self.step_method_below.tuning_end_trigger = True
                     self.tuning_end_trigger = False

-                self.trace = subsample(
+                trace = subsample(
                     draws=self.subsampling_rate,
                     step=self.step_method_below,
                     start=q0_dict,
-                    trace=self.trace,
+                    tune=0,
                 )

         # set logging back to normal
@@ -1135,7 +1117,13 @@ def __call__(self, q0_dict: dict) -> dict:
         # return sample with index self.subchain_selection from the generated
         # sequence of length self.subsampling_rate. The index is set within
         # MLDA's astep() function
-        new_point = self.trace.point(-self.subsampling_rate + self.subchain_selection)
-        new_point = Point(new_point, model=self.model_below, filter_model_vars=True)
+        q_dict = trace.point(self.subchain_selection)
+
+        # Make sure output dict is ordered the same way as the input dict.
+        q_dict = Point(
+            {key: q_dict[key] for key in q0_dict.keys()},
+            model=self.model_below,
+            filter_model_vars=True,
+        )

-        return new_point
+        return DictToArrayBijection.map(q_dict)
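
The re-keying before DictToArrayBijection.map matters because the raveled array concatenates variables in dict order, which must match the layout of the incoming q0. A small sketch of the guard (illustrative names; it assumes map concatenates values in iteration order):

import numpy as np
from pymc.blocking import DictToArrayBijection

q0_dict = {"a": np.array([1.0, 2.0]), "b": np.array(3.0)}
trace_point = {"b": np.array(4.0), "a": np.array([5.0, 6.0])}  # keys in a different order

# Re-key the subchain point in q0's order before raveling, mirroring the diff above.
ordered = {key: trace_point[key] for key in q0_dict.keys()}
q = DictToArrayBijection.map(ordered)
print(q.data)  # expected: [5. 6. 4.], i.e. the "a" block first, matching q0's layout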
