diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py
index 02e77d6a14..a96f68068d 100644
--- a/pymc/step_methods/metropolis.py
+++ b/pymc/step_methods/metropolis.py
@@ -168,10 +168,12 @@ def __init__(
             vars = model.value_vars
         else:
             vars = [model.rvs_to_values.get(var, var) for var in vars]
+        vars = pm.inputvars(vars)
+        initial_values_shape = [initial_values[v.name].shape for v in vars]
         if S is None:
-            S = np.ones(sum(initial_values[v.name].size for v in vars))
+            S = np.ones(int(sum(np.prod(ivs) for ivs in initial_values_shape)))
 
         if proposal_dist is not None:
             self.proposal_dist = proposal_dist(S)
@@ -186,7 +188,6 @@ def __init__(
         self.tune = tune
         self.tune_interval = tune_interval
         self.steps_until_tune = tune_interval
-        self.accepted = 0
 
         # Determine type of variables
         self.discrete = np.concatenate(
@@ -195,11 +196,33 @@ def __init__(
         self.any_discrete = self.discrete.any()
         self.all_discrete = self.discrete.all()
 
-        # remember initial settings before tuning so they can be reset
-        self._untuned_settings = dict(
-            scaling=self.scaling, steps_until_tune=tune_interval, accepted=self.accepted
+        # Metropolis will try to handle one batched dimension at a time. This, however,
+        # is not safe for discrete multivariate distributions (looking at you Multinomial),
+        # due to high dependency among the support dimensions. For continuous multivariate
+        # distributions we assume they are being transformed in a way that makes each
+        # dimension semi-independent.
+        is_scalar = len(initial_values_shape) == 1 and initial_values_shape[0] == ()
+        self.elemwise_update = not (
+            is_scalar
+            or (
+                self.any_discrete
+                and max(getattr(model.values_to_rvs[var].owner.op, "ndim_supp", 1) for var in vars)
+                > 0
+            )
         )
+        if self.elemwise_update:
+            dims = int(sum(np.prod(ivs) for ivs in initial_values_shape))
+        else:
+            dims = 1
+        self.enum_dims = np.arange(dims, dtype=int)
+        self.accept_rate_iter = np.zeros(dims, dtype=float)
+        self.accepted_iter = np.zeros(dims, dtype=bool)
+        self.accepted_sum = np.zeros(dims, dtype=int)
+        # remember initial settings before tuning so they can be reset
+        self._untuned_settings = dict(scaling=self.scaling, steps_until_tune=tune_interval)
+
+        # TODO: This is not being used when compiling the logp function!
         self.mode = mode
 
         shared = pm.make_shared_replacements(initial_values, vars, model)
@@ -210,6 +233,7 @@ def reset_tuning(self):
         """Resets the tuned sampler parameters to their initial values."""
         for attr, initial_value in self._untuned_settings.items():
             setattr(self, attr, initial_value)
+        self.accepted_sum[:] = 0
         return
 
     def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
@@ -219,10 +243,10 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
 
         if not self.steps_until_tune and self.tune:
             # Tune scaling parameter
-            self.scaling = tune(self.scaling, self.accepted / float(self.tune_interval))
+            self.scaling = tune(self.scaling, self.accepted_sum / float(self.tune_interval))
             # Reset counter
             self.steps_until_tune = self.tune_interval
-            self.accepted = 0
+            self.accepted_sum[:] = 0
 
         delta = self.proposal_dist() * self.scaling
 
@@ -237,23 +261,36 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         else:
             q = floatX(q0 + delta)
 
-        accept = self.delta_logp(q, q0)
-        q_new, accepted = metrop_select(accept, q, q0)
-
-        self.accepted += accepted
+        if self.elemwise_update:
+            q_temp = q0.copy()
+            # Shuffle order of updates (probably we don't need to do this in every step)
+            np.random.shuffle(self.enum_dims)
+            for i in self.enum_dims:
+                q_temp[i] = q[i]
+                accept_rate_i = self.delta_logp(q_temp, q0)
+                q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0)
+                q_temp[i] = q_temp_[i]
+                self.accept_rate_iter[i] = accept_rate_i
+                self.accepted_iter[i] = accepted_i
+                self.accepted_sum[i] += accepted_i
+            q = q_temp
+        else:
+            accept_rate = self.delta_logp(q, q0)
+            q, accepted = metrop_select(accept_rate, q, q0)
+            self.accept_rate_iter = accept_rate
+            self.accepted_iter = accepted
+            self.accepted_sum += accepted
 
         self.steps_until_tune -= 1
 
         stats = {
             "tune": self.tune,
-            "scaling": self.scaling,
-            "accept": np.exp(accept),
-            "accepted": accepted,
+            "scaling": np.mean(self.scaling),
+            "accept": np.mean(np.exp(self.accept_rate_iter)),
+            "accepted": np.mean(self.accepted_iter),
         }
 
-        q_new = RaveledVars(q_new, point_map_info)
-
-        return q_new, [stats]
+        return RaveledVars(q, point_map_info), [stats]
 
     @staticmethod
     def competence(var, has_grad):
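Note for reviewers: the `elemwise_update` branch above is the heart of this change. A minimal standalone sketch of the idea (plain NumPy, not the PyMC API; `delta_logp` stands in for the compiled log-density difference and `metrop_select` is replaced by an explicit accept test):

```python
import numpy as np

def elemwise_metropolis(q0, q, delta_logp, rng=np.random):
    """One elemwise sweep. q0: current point, q: full vector proposal,
    delta_logp(a, b) = logp(a) - logp(b)."""
    q_temp = q0.copy()
    accepted = np.zeros(q0.size, dtype=bool)
    for i in rng.permutation(q0.size):  # shuffled update order, as in astep
        q_temp[i] = q[i]  # propose a change in coordinate i only
        if np.log(rng.uniform()) < delta_logp(q_temp, q0):
            accepted[i] = True  # keep q_temp[i] = q[i]
        else:
            q_temp[i] = q0[i]  # revert coordinate i
    return q_temp, accepted
```

Because one poorly scaled coordinate can no longer veto the entire vector proposal, each dimension also gets its own acceptance counter (`accepted_sum`), which the vectorized `tune` below consumes.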
@@ -275,26 +312,38 @@ def tune(scale, acc_rate):
     >0.95         x 10
 
     """
-    if acc_rate < 0.001:
+    return scale * np.where(
+        acc_rate < 0.001,
         # reduce by 90 percent
-        return scale * 0.1
-    elif acc_rate < 0.05:
-        # reduce by 50 percent
-        return scale * 0.5
-    elif acc_rate < 0.2:
-        # reduce by ten percent
-        return scale * 0.9
-    elif acc_rate > 0.95:
-        # increase by factor of ten
-        return scale * 10.0
-    elif acc_rate > 0.75:
-        # increase by double
-        return scale * 2.0
-    elif acc_rate > 0.5:
-        # increase by ten percent
-        return scale * 1.1
-
-    return scale
+        0.1,
+        np.where(
+            acc_rate < 0.05,
+            # reduce by 50 percent
+            0.5,
+            np.where(
+                acc_rate < 0.2,
+                # reduce by ten percent
+                0.9,
+                np.where(
+                    acc_rate > 0.95,
+                    # increase by factor of ten
+                    10.0,
+                    np.where(
+                        acc_rate > 0.75,
+                        # increase by double
+                        2.0,
+                        np.where(
+                            acc_rate > 0.5,
+                            # increase by ten percent
+                            1.1,
+                            # Do not change
+                            1.0,
+                        ),
+                    ),
+                ),
+            ),
+        ),
+    )
 
 
 class BinaryMetropolis(ArrayStep):
diff --git a/pymc/step_methods/mlda.py b/pymc/step_methods/mlda.py
index 873bd3ea1a..441890ebbd 100644
--- a/pymc/step_methods/mlda.py
+++ b/pymc/step_methods/mlda.py
@@ -787,11 +787,11 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         if isinstance(self.step_method_below, MLDA):
             self.base_tuning_stats = self.step_method_below.base_tuning_stats
         elif isinstance(self.step_method_below, MetropolisMLDA):
-            self.base_tuning_stats.append({"base_scaling": self.step_method_below.scaling})
+            self.base_tuning_stats.append({"base_scaling": np.mean(self.step_method_below.scaling)})
         elif isinstance(self.step_method_below, DEMetropolisZMLDA):
             self.base_tuning_stats.append(
                 {
-                    "base_scaling": self.step_method_below.scaling,
+                    "base_scaling": np.mean(self.step_method_below.scaling),
                     "base_lambda": self.step_method_below.lamb,
                 }
             )
@@ -799,10 +799,10 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
             # Below method is CompoundStep
             for method in self.step_method_below.methods:
                 if isinstance(method, MetropolisMLDA):
-                    self.base_tuning_stats.append({"base_scaling": method.scaling})
+                    self.base_tuning_stats.append({"base_scaling": np.mean(method.scaling)})
                 elif isinstance(method, DEMetropolisZMLDA):
                     self.base_tuning_stats.append(
-                        {"base_scaling": method.scaling, "base_lambda": method.lamb}
+                        {"base_scaling": np.mean(method.scaling), "base_lambda": method.lamb}
                     )
 
         return q_new, [stats] + self.base_tuning_stats
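The `np.mean(...)` wrapping in the MLDA hunks above is needed because, with elemwise tuning, `scaling` can now be a per-dimension vector while the `base_scaling` stat must stay scalar. A quick standalone check of the vectorized `tune` helper (a sketch assuming this patch is applied; expected values follow the docstring table):

```python
import numpy as np

from pymc.step_methods.metropolis import tune  # module-level helper rewritten above

# With the nested-np.where rewrite, `tune` broadcasts: each dimension's scale
# is rescaled according to its own acceptance rate.
acc_rate = np.array([0.0005, 0.04, 0.15, 0.35, 0.6, 0.8, 0.99])
print(tune(np.ones(7), acc_rate))
# -> [ 0.1  0.5  0.9  1.   1.1  2.  10. ]
```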
diff --git a/pymc/tests/test_step.py b/pymc/tests/test_step.py
index 9df55f57b6..f23f9bcf40 100644
--- a/pymc/tests/test_step.py
+++ b/pymc/tests/test_step.py
@@ -35,7 +35,9 @@
     Beta,
     Binomial,
     Categorical,
+    Dirichlet,
     HalfNormal,
+    Multinomial,
     MvNormal,
     Normal,
 )
@@ -174,33 +176,6 @@ def test_step_categorical(self, proposal):
         self.check_stat(check, idata, step.__class__.__name__)
 
 
-class TestMetropolisProposal:
-    def test_proposal_choice(self):
-        with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
-            _, model, _ = mv_simple()
-            with model:
-                initial_point = model.initial_point()
-                initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)
-
-                s = np.ones(initial_point_size)
-                sampler = Metropolis(S=s)
-                assert isinstance(sampler.proposal_dist, NormalProposal)
-                s = np.diag(s)
-                sampler = Metropolis(S=s)
-                assert isinstance(sampler.proposal_dist, MultivariateNormalProposal)
-                s[0, 0] = -s[0, 0]
-                with pytest.raises(np.linalg.LinAlgError):
-                    sampler = Metropolis(S=s)
-
-    def test_mv_proposal(self):
-        np.random.seed(42)
-        cov = np.random.randn(5, 5)
-        cov = cov.dot(cov.T)
-        prop = MultivariateNormalProposal(cov)
-        samples = np.array([prop() for _ in range(10000)])
-        npt.assert_allclose(np.cov(samples.T), cov, rtol=0.2)
-
-
 class TestCompoundStep:
     samplers = (Metropolis, Slice, HamiltonianMC, NUTS, DEMetropolis)
 
@@ -383,6 +358,31 @@ def test_parallelized_chains_are_random(self):
 
 class TestMetropolis:
+    def test_proposal_choice(self):
+        with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
+            _, model, _ = mv_simple()
+            with model:
+                initial_point = model.initial_point()
+                initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)
+
+                s = np.ones(initial_point_size)
+                sampler = Metropolis(S=s)
+                assert isinstance(sampler.proposal_dist, NormalProposal)
+                s = np.diag(s)
+                sampler = Metropolis(S=s)
+                assert isinstance(sampler.proposal_dist, MultivariateNormalProposal)
+                s[0, 0] = -s[0, 0]
+                with pytest.raises(np.linalg.LinAlgError):
+                    sampler = Metropolis(S=s)
+
+    def test_mv_proposal(self):
+        np.random.seed(42)
+        cov = np.random.randn(5, 5)
+        cov = cov.dot(cov.T)
+        prop = MultivariateNormalProposal(cov)
+        samples = np.array([prop() for _ in range(10000)])
+        npt.assert_allclose(np.cov(samples.T), cov, rtol=0.2)
+
     def test_tuning_reset(self):
         """Re-use of the step method instance with cores=1 must not leak tuning
         information between chains."""
         with Model() as pmodel:
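Side note on the relocated tests: proposal selection depends only on the dimensionality of `S`, and the proposal objects are directly callable, which `test_mv_proposal` relies on. A hypothetical standalone snippet illustrating both (same names this test module imports):

```python
import numpy as np
from pymc.step_methods.metropolis import MultivariateNormalProposal, NormalProposal

s = np.ones(3)
print(NormalProposal(s)().shape)  # (3,): independent draws, one scale per dimension
prop = MultivariateNormalProposal(np.diag(s))  # a 2-D S (covariance) selects this proposal
print(prop().shape)  # (3,): correlated draws from the given covariance
```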
@@ -403,6 +403,40 @@ def test_tuning_reset(self):
             assert tuned != 0.1
             np.testing.assert_array_equal(idata.sample_stats["scaling"].sel(chain=c).values, tuned)
 
+    @pytest.mark.parametrize(
+        "batched_dist",
+        (
+            Binomial.dist(n=5, p=0.9),  # scalar case
+            Binomial.dist(n=np.arange(40) + 1, p=np.linspace(0.1, 0.9, 40), shape=(40,)),
+            Binomial.dist(
+                n=(np.arange(20) + 1)[::-1],
+                p=np.linspace(0.1, 0.9, 20),
+                shape=(
+                    2,
+                    20,
+                ),
+            ),
+            Dirichlet.dist(a=np.ones(3) * (np.arange(40) + 1)[:, None], shape=(40, 3)),
+            Dirichlet.dist(a=np.ones(3) * (np.arange(20) + 1)[:, None], shape=(2, 20, 3)),
+        ),
+    )
+    def test_elemwise_update(self, batched_dist):
+        with Model() as m:
+            m.register_rv(batched_dist, name="batched_dist")
+            step = pm.Metropolis([batched_dist])
+            assert step.elemwise_update == (batched_dist.ndim > 0)
+            trace = pm.sample(draws=1000, chains=2, step=step, random_seed=428)
+
+        assert az.rhat(trace).max()["batched_dist"].values < 1.1
+        assert az.ess(trace).min()["batched_dist"].values > 50
+
+    def test_multinomial_no_elemwise_update(self):
+        with Model() as m:
+            batched_dist = Multinomial("batched_dist", n=5, p=np.ones(4) / 4, shape=(10, 4))
+            with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
+                step = pm.Metropolis([batched_dist])
+                assert not step.elemwise_update
+
 
 class TestDEMetropolisZ:
     def test_tuning_lambda_sequential(self):
@@ -1217,8 +1251,6 @@ def perform(self, node, inputs, outputs):
     mout = []
     coarse_models = []
 
-    rng = np.random.RandomState(seed)
-
     with Model() as coarse_model_0:
         if aesara.config.floatX == "float32":
             Q = Data("Q", np.float32(0.0))
@@ -1236,8 +1268,6 @@ def perform(self, node, inputs, outputs):
 
     coarse_models.append(coarse_model_0)
 
-    rng = np.random.RandomState(seed)
-
     with Model() as coarse_model_1:
         if aesara.config.floatX == "float32":
             Q = Data("Q", np.float32(0.0))
@@ -1255,8 +1285,6 @@ def perform(self, node, inputs, outputs):
 
     coarse_models.append(coarse_model_1)
 
-    rng = np.random.RandomState(seed)
-
    with Model() as model:
         if aesara.config.floatX == "float32":
             Q = Data("Q", np.float32(0.0))
@@ -1314,8 +1342,9 @@ def perform(self, node, inputs, outputs):
         (nchains, ndraws * nsub)
     )
     Q_2_1 = np.concatenate(trace.get_sampler_stats("Q_2_1")).reshape((nchains, ndraws))
-    assert Q_1_0.mean(axis=1) == 0.0
-    assert Q_2_1.mean(axis=1) == 0.0
+    # This used to be a strict zero equality!
+    assert np.isclose(Q_1_0.mean(axis=1), 0.0, atol=1e-4)
+    assert np.isclose(Q_2_1.mean(axis=1), 0.0, atol=1e-4)
 
 
 class TestRVsAssignmentSteps:
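Finally, a hedged end-to-end sketch of what the new `test_elemwise_update` exercises (hypothetical example, not part of the diff): a batch of univariate distributions with very different scales, where per-dimension tuning is what makes plain Metropolis mix well:

```python
import numpy as np
import pymc as pm

with pm.Model():
    # 40 independent Binomials with widely varying n and p: a single shared
    # accept/reject decision would be dominated by the worst-scaled dimension.
    x = pm.Binomial("x", n=np.arange(40) + 1, p=np.linspace(0.1, 0.9, 40), shape=(40,))
    step = pm.Metropolis([x])
    assert step.elemwise_update  # batched and not a multivariate discrete distribution
    idata = pm.sample(draws=1000, chains=2, step=step, random_seed=428)
```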