Skip to content

Commit ba41e95

Browse files
committed
Do not add transformed value names to named_vars
This also disables prior_predictive sampling of transformed variables
1 parent ed8bfc0 commit ba41e95

File tree

6 files changed

+33
-77
lines changed

6 files changed

+33
-77
lines changed

pymc/model.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,7 +1480,6 @@ def create_value_var(
14801480
value_var.tag.test_value = transform.forward(
14811481
value_var, *rv_var.owner.inputs
14821482
).tag.test_value
1483-
self.named_vars[value_var.name] = value_var
14841483
self.rvs_to_transforms[rv_var] = transform
14851484
self.rvs_to_values[rv_var] = value_var
14861485
self.values_to_rvs[value_var] = rv_var
@@ -1704,14 +1703,17 @@ def check_start_vals(self, start):
17041703
None
17051704
"""
17061705
start_points = [start] if isinstance(start, dict) else start
1706+
1707+
value_names_to_dtypes = {value.name: value.dtype for value in self.value_vars}
1708+
value_names_set = set(value_names_to_dtypes.keys())
17071709
for elem in start_points:
17081710

17091711
for k, v in elem.items():
1710-
elem[k] = np.asarray(v, dtype=self[k].dtype)
1712+
elem[k] = np.asarray(v, dtype=value_names_to_dtypes[k])
17111713

1712-
if not set(elem.keys()).issubset(self.named_vars.keys()):
1713-
extra_keys = ", ".join(set(elem.keys()) - set(self.named_vars.keys()))
1714-
valid_keys = ", ".join(self.named_vars.keys())
1714+
if not set(elem.keys()).issubset(value_names_set):
1715+
extra_keys = ", ".join(set(elem.keys()) - value_names_set)
1716+
valid_keys = ", ".join(value_names_set)
17151717
raise KeyError(
17161718
"Some start parameters do not appear in the model!\n"
17171719
f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"

pymc/sampling/forward.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def sample_prior_predictive(
343343
var_names : Iterable[str]
344344
A list of names of variables for which to compute the prior predictive
345345
samples. Defaults to both observed and unobserved RVs. Transformed values
346-
are not included unless explicitly defined in var_names.
346+
are not allowed.
347347
random_seed : int, RandomState or Generator, optional
348348
Seed for the random number generator.
349349
return_inferencedata : bool
@@ -382,23 +382,10 @@ def sample_prior_predictive(
382382
names = sorted(get_default_varnames(vars_, include_transformed=False))
383383
vars_to_sample = [model[name] for name in names]
384384

385-
# Any variables from var_names that are missing must be transformed variables.
386-
# Misspelled variables would have raised a KeyError above.
385+
# Any variables from var_names still missing are assumed to be transformed variables.
387386
missing_names = vars_.difference(names)
388-
for name in sorted(missing_names):
389-
transformed_value_var = model[name]
390-
rv_var = model.values_to_rvs[transformed_value_var]
391-
transform = model.rvs_to_transforms[rv_var]
392-
transformed_rv_var = transform.forward(rv_var, *rv_var.owner.inputs)
393-
394-
names.append(name)
395-
vars_to_sample.append(transformed_rv_var)
396-
397-
# If the user asked for the transformed variable in var_names, but not the
398-
# original RV, we add it manually here
399-
if rv_var.name not in names:
400-
names.append(rv_var.name)
401-
vars_to_sample.append(rv_var)
387+
if missing_names:
388+
raise ValueError(f"Unrecognized var_names: {missing_names}")
402389

403390
if random_seed is not None:
404391
(random_seed,) = _get_seeds_per_chain(random_seed, 1)

pymc/smc/kernels.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@
3333
)
3434
from pymc.backends.ndarray import NDArray
3535
from pymc.blocking import DictToArrayBijection
36+
from pymc.initial_point import make_initial_point_expression
3637
from pymc.model import Point, modelcontext
37-
from pymc.sampling.forward import sample_prior_predictive
38+
from pymc.sampling.forward import draw
3839
from pymc.step_methods.metropolis import MultivariateNormalProposal
3940
from pymc.vartypes import discrete_types
4041

@@ -182,13 +183,20 @@ def initialize_population(self) -> Dict[str, np.ndarray]:
182183
"ignore", category=UserWarning, message="The effect of Potentials"
183184
)
184185

185-
result = sample_prior_predictive(
186-
self.draws,
187-
var_names=[v.name for v in self.model.unobserved_value_vars],
188-
model=self.model,
189-
return_inferencedata=False,
186+
model = self.model
187+
prior_expression = make_initial_point_expression(
188+
free_rvs=model.free_RVs,
189+
rvs_to_transforms=model.rvs_to_transforms,
190+
initval_strategies={},
191+
default_strategy="prior",
192+
return_transformed=True,
190193
)
191-
return cast(Dict[str, np.ndarray], result)
194+
prior_values = draw(prior_expression, draws=self.draws, random_seed=self.rng)
195+
196+
names = [model.rvs_to_values[rv].name for rv in model.free_RVs]
197+
dict_prior = {k: np.stack(v) for k, v in zip(names, prior_values)}
198+
199+
return cast(Dict[str, np.ndarray], dict_prior)
192200

193201
def _initialize_kernel(self):
194202
"""Create variables and logp function necessary to run kernel
@@ -325,12 +333,11 @@ def _posterior_to_trace(self, chain=0) -> NDArray:
325333
for i in range(lenght_pos):
326334
value = []
327335
size = 0
328-
for varname in varnames:
329-
shape, new_size = self.var_info[varname]
336+
for var in self.variables:
337+
shape, new_size = self.var_info[var.name]
330338
var_samples = self.tempered_posterior[i][size : size + new_size]
331339
# Round discrete variable samples. The rounded values were the ones
332340
# actually used in the logp evaluations (see logp_forw)
333-
var = self.model[varname]
334341
if var.dtype in discrete_types:
335342
var_samples = np.round(var_samples).astype(var.dtype)
336343
value.append(var_samples.reshape(shape))

pymc/tests/distributions/test_continuous.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ def random_polyagamma(*args, **kwargs):
7070

7171
class TestBoundedContinuous:
7272
def get_dist_params_and_interval_bounds(self, model, rv_name):
73-
interval_rv = model.named_vars[f"{rv_name}_interval__"]
7473
rv = model.named_vars[rv_name]
7574
dist_params = rv.owner.inputs
7675
lower_interval, upper_interval = model.rvs_to_transforms[rv].args_fn(*rv.owner.inputs)

pymc/tests/sampling/test_forward.py

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,51 +1175,13 @@ def test_potentials_warning(self):
11751175
with pytest.warns(UserWarning, match=warning_msg):
11761176
pm.sample_prior_predictive(samples=5)
11771177

1178-
def test_transformed_vars(self):
1179-
# Test that prior predictive returns transformation of RVs when these are
1180-
# passed explicitly in `var_names`
1181-
1182-
def ub_interval_forward(x, ub):
1183-
# Interval transform assuming lower bound is zero
1184-
return np.log(x - 0) - np.log(ub - x)
1185-
1178+
def test_transformed_vars_not_supported(self):
11861179
with pm.Model() as model:
11871180
ub = pm.HalfNormal("ub", 10)
11881181
x = pm.Uniform("x", 0, ub)
11891182

1190-
prior = pm.sample_prior_predictive(
1191-
var_names=["ub", "ub_log__", "x", "x_interval__"],
1192-
samples=10,
1193-
random_seed=123,
1194-
)
1195-
1196-
# Check values are correct
1197-
assert np.allclose(prior.prior["ub_log__"].data, np.log(prior.prior["ub"].data))
1198-
assert np.allclose(
1199-
prior.prior["x_interval__"].data,
1200-
ub_interval_forward(prior.prior["x"].data, prior.prior["ub"].data),
1201-
)
1202-
1203-
# Check that it works when the original RVs are not mentioned in var_names
1204-
with pm.Model() as model_transformed_only:
1205-
ub = pm.HalfNormal("ub", 10)
1206-
x = pm.Uniform("x", 0, ub)
1207-
1208-
prior_transformed_only = pm.sample_prior_predictive(
1209-
var_names=["ub_log__", "x_interval__"],
1210-
samples=10,
1211-
random_seed=123,
1212-
)
1213-
assert (
1214-
"ub" not in prior_transformed_only.prior.data_vars
1215-
and "x" not in prior_transformed_only.prior.data_vars
1216-
)
1217-
assert np.allclose(
1218-
prior.prior["ub_log__"].data, prior_transformed_only.prior["ub_log__"].data
1219-
)
1220-
assert np.allclose(
1221-
prior.prior["x_interval__"], prior_transformed_only.prior["x_interval__"].data
1222-
)
1183+
with pytest.raises(ValueError, match="Unrecognized var_names"):
1184+
pm.sample_prior_predictive(var_names=["ub", "ub_log__", "x", "x_interval__"])
12231185

12241186
def test_issue_4490(self):
12251187
# Test that samples do not depend on var_name order or, more fundamentally,

pymc/tests/test_model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,9 +1206,8 @@ def test_interval_missing_observations(self):
12061206
with pytest.warns(ImputationWarning):
12071207
theta2 = pm.Normal("theta2", mu=theta1, observed=obs2, rng=rng)
12081208

1209-
assert "theta1_observed" in model.named_vars
1210-
assert "theta1_missing_interval__" in model.named_vars
1211-
assert model.rvs_to_transforms[model.named_vars["theta1_observed"]] is None
1209+
assert isinstance(model.rvs_to_transforms[model["theta1_missing"]], IntervalTransform)
1210+
assert model.rvs_to_transforms[model["theta1_observed"]] is None
12121211

12131212
prior_trace = pm.sample_prior_predictive(return_inferencedata=False)
12141213

0 commit comments

Comments
 (0)