
Commit 3a884dd

Distinguish better observed from constant data
This avoids having to set dummy observed data when calling sample_posterior_predictive in cases where the observed data is not part of the generative graph.
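Roughly, the workflow this unlocks (a minimal sketch distilled from the tests added below, not code from this commit; only standard PyMC entry points are used):

```python
import numpy as np
import pymc as pm

with pm.Model(coords={"trial": range(5)}) as m:
    x = pm.Data("x", np.random.normal(size=5), dims=("trial",))
    y = pm.Data("y", np.random.normal(size=5), dims=("trial",))
    sigma = pm.HalfNormal("sigma")
    # y only supplies observed values; it is not an input of the generative graph
    pm.Normal("obs", mu=x, sigma=sigma, observed=y, shape=x.shape, dims=("trial",))
    idata = pm.sample()

# Predict for 2 new trials: only x needs a new value. Before this commit,
# y also had to be resized with dummy values to avoid a dimension conflict.
with m:
    pm.set_data({"x": np.random.normal(size=2)}, coords={"trial": range(2)})
    pm.sample_posterior_predictive(idata, predictions=True)
```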
1 parent ba60b79 · commit 3a884dd

File tree

6 files changed: +106, -35 lines


pymc/backends/arviz.py

Lines changed: 14 additions & 24 deletions

@@ -30,7 +30,7 @@
 
 from arviz import InferenceData, concat, rcParams
 from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
-from pytensor.graph.basic import Constant
+from pytensor.graph import ancestors
 from pytensor.tensor.sharedvar import SharedVariable
 from rich.progress import Console, Progress
 from rich.theme import Theme

@@ -72,31 +72,21 @@ def find_observations(model: "Model") -> dict[str, Var]:
 
 def find_constants(model: "Model") -> dict[str, Var]:
     """If there are constants available, return them as a dictionary."""
+    model_vars = model.basic_RVs + model.deterministics + model.potentials
+    value_vars = set(model.rvs_to_values.values())
 
-    # The constant data vars must be either pm.Data or TensorConstant or SharedVariable
-    def is_data(name, var, model) -> bool:
-        observations = find_observations(model)
-        return (
-            var not in model.deterministics
-            and var not in model.observed_RVs
-            and var not in model.free_RVs
-            and var not in model.potentials
-            and var not in model.value_vars
-            and name not in observations
-            and isinstance(var, Constant | SharedVariable)
-        )
-
-    # The assumption is that constants (like pm.Data) are named
-    # variables that aren't observed or free RVs, nor are they
-    # deterministics, and then we eliminate observations.
     constant_data = {}
-    for name, var in model.named_vars.items():
-        if is_data(name, var, model):
-            if hasattr(var, "get_value"):
-                var = var.get_value()
-            elif hasattr(var, "data"):
-                var = var.data
-            constant_data[name] = var
+    for var in model.data_vars:
+        if var in value_vars:
+            # An observed value variable could also be part of the generative graph
+            if var not in ancestors(model_vars):
+                continue
+
+        if isinstance(var, SharedVariable):
+            var_value = var.get_value()
+        else:
+            var_value = var.data
+        constant_data[var.name] = var_value
 
     return constant_data
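The net effect: find_constants now iterates only over the explicitly registered model.data_vars, and keeps an observed value variable only when it is also an ancestor of the generative graph. A small illustration of the intended behavior after this change (a sketch, not part of the diff):

```python
import numpy as np
import pymc as pm
from pymc.backends.arviz import find_constants, find_observations

with pm.Model() as model:
    x = pm.Data("x", np.arange(3.0))  # predictor: an ancestor of the model graph
    y = pm.Data("y", np.ones(3))      # pure observed value: not an ancestor
    pm.Normal("obs", mu=x, sigma=1.0, observed=y)

print(sorted(find_constants(model)))     # expected: ['x']
print(sorted(find_observations(model)))  # expected: ['obs']
```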

pymc/data.py

Lines changed: 1 addition & 1 deletion

@@ -444,6 +444,6 @@ def Data(
                 length=xshape[d],
             )
 
-    model.add_named_variable(x, dims=dims)
+    model.register_data_var(x, dims=dims)
 
     return x
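pm.Data now registers through the new register_data_var hook instead of calling add_named_variable directly, so the variable is tracked in model.data_vars in addition to model.named_vars. A quick check of that bookkeeping (illustrative sketch):

```python
import pymc as pm

with pm.Model() as model:
    x = pm.Data("x", [1.0, 2.0, 3.0])

assert x in model.data_vars      # new: explicit registry of data variables
assert "x" in model.named_vars   # unchanged: still registered by name
```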

pymc/model/core.py

Lines changed: 7 additions & 0 deletions

@@ -531,6 +531,7 @@ def __init__(
             self.observed_RVs = treelist(parent=self.parent.observed_RVs)
             self.deterministics = treelist(parent=self.parent.deterministics)
             self.potentials = treelist(parent=self.parent.potentials)
+            self.data_vars = treelist(parent=self.parent.data_vars)
             self._coords = self.parent._coords
             self._dim_lengths = self.parent._dim_lengths
         else:

@@ -544,6 +545,7 @@ def __init__(
             self.observed_RVs = treelist()
             self.deterministics = treelist()
             self.potentials = treelist()
+            self.data_vars = treelist()
             self._coords = {}
             self._dim_lengths = {}
             self.add_coords(coords)

@@ -1483,6 +1485,11 @@ def create_value_var(
 
         return value_var
 
+    def register_data_var(self, data, dims=None):
+        """Register a data variable with the model."""
+        self.data_vars.append(data)
+        self.add_named_variable(data, dims=dims)
+
     def add_named_variable(self, var, dims: tuple[str | None, ...] | None = None):
         """Add a random graph variable to the named variables of the model.
 
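register_data_var is a thin wrapper (append to data_vars, then forward to add_named_variable), so code that builds data variables by hand could call it directly. A speculative sketch, assuming a shared variable whose name is already set:

```python
import numpy as np
import pytensor
import pymc as pm

with pm.Model() as model:
    raw = pytensor.shared(np.array([1.0, 2.0, 3.0]), name="raw")
    model.register_data_var(raw)  # tracked like a pm.Data variable

assert raw in model.data_vars
assert "raw" in model.named_vars
```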

pymc/model/fgraph.py

Lines changed: 9 additions & 8 deletions

@@ -164,30 +164,30 @@ def fgraph_from_model(
     free_rvs = model.free_RVs
     observed_rvs = model.observed_RVs
     potentials = model.potentials
-    named_vars = model.named_vars.values()
     # We copy Deterministics (Identity Op) so that they don't show in between "main" variables
     # We later remove these Identity Ops when we have a Deterministic ModelVar Op as a separator
     old_deterministics = model.deterministics
     deterministics = [det if inlined_views else det.copy(det.name) for det in old_deterministics]
     # Value variables (we also have to decide whether to inline named ones)
     old_value_vars = list(rvs_to_values.values())
-    unnamed_value_vars = [val for val in old_value_vars if val not in named_vars]
+    data_vars = model.data_vars
+    unnamed_value_vars = [val for val in old_value_vars if val not in data_vars]
     named_value_vars = [
-        val if inlined_views else val.copy(val.name) for val in old_value_vars if val in named_vars
+        val if inlined_views else val.copy(name=val.name)
+        for val in old_value_vars
+        if val in data_vars
     ]
     value_vars = old_value_vars.copy()
     if inlined_views:
         # In this case we want to use the named_value_vars as the value_vars in RVs
         for named_val in named_value_vars:
             idx = value_vars.index(named_val)
             value_vars[idx] = named_val
-    # Other variables that are in named_vars but are not any of the categories above (e.g., Data)
-    # We use the same trick as deterministics!
-    accounted_for = set(free_rvs + observed_rvs + potentials + old_deterministics + old_value_vars)
+    # Data vars that are not value vars
     other_named_vars = [
         var if inlined_views else var.copy(var.name)
-        for var in named_vars
-        if var not in accounted_for
+        for var in data_vars
+        if var not in old_value_vars
     ]
 
     model_vars = (

@@ -355,6 +355,7 @@ def first_non_model_var(var):
             model.deterministics.append(var)
         elif isinstance(model_var.owner.op, ModelNamed):
             var, *dims = model_var.owner.inputs
+            model.data_vars.append(var)
         else:
             raise TypeError(f"Unexpected ModelVar type {type(model_var)}")
 
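With data variables tracked explicitly, fgraph_from_model no longer infers data vars by eliminating every other category, and model_from_fgraph can rebuild model.data_vars from the ModelNamed wrappers. A sketch of the round trip this supports (function names as in this file; behavior assumed from the diff):

```python
import pymc as pm
from pymc.model.fgraph import fgraph_from_model, model_from_fgraph

with pm.Model() as m:
    x = pm.Data("x", [1.0, 2.0, 3.0])
    pm.Normal("y", mu=x.sum())

fgraph, memo = fgraph_from_model(m)
m2 = model_from_fgraph(fgraph)
assert [var.name for var in m2.data_vars] == ["x"]
```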

tests/backends/test_arviz.py

Lines changed: 26 additions & 2 deletions

@@ -454,14 +454,38 @@ def test_constant_data(self, use_context):
         test_dict = {
             "posterior": ["beta"],
             "observed_data": ["obs"],
-            "constant_data": ["x", "y", "beta_sigma"],
+            "constant_data": ["x", "beta_sigma"],
         }
         fails = check_multiple_attrs(test_dict, inference_data)
         assert not fails
         assert inference_data.log_likelihood["obs"].shape == (2, 100, 3)
         # test that scalars are dimensionless in constant_data (issue #6755)
         assert inference_data.constant_data["beta_sigma"].ndim == 0
 
+    @pytest.mark.parametrize("constant_in_generative_graph", [True, False])
+    def test_observed_data_also_constant(self, constant_in_generative_graph):
+        """Test that when the same variable is used as constant data and observed data, it shows up in both groups."""
+        with pm.Model(coords={"trial": [0, 1, 2]}) as model:
+            x = pm.Data("x", [1.0, 2.0, 3.0], dims=["trial"])
+            sigma = pm.HalfNormal("sigma", 1)
+            mu = x - 1 if constant_in_generative_graph else 0
+            pm.Normal("y", mu, sigma, observed=x, dims=["trial"])
+
+        trace = pm.sample_prior_predictive(100, return_inferencedata=False)
+
+        inference_data = to_inference_data(prior=trace, model=model, log_likelihood=False)
+
+        test_dict = {
+            "prior": ["sigma"],
+            "observed_data": ["y"],
+        }
+        if constant_in_generative_graph:
+            test_dict["constant_data"] = ["x"]
+        else:
+            test_dict["~constant_data"] = []
+        fails = check_multiple_attrs(test_dict, inference_data)
+        assert not fails
+
     def test_predictions_constant_data(self):
         with pm.Model():
             x = pm.Data("x", [1.0, 2.0, 3.0])

@@ -548,7 +572,7 @@ def test_priors_separation(self, use_context):
             "prior": ["beta", "~obs"],
             "observed_data": ["obs"],
             "prior_predictive": ["obs"],
-            "constant_data": ["x", "y"],
+            "constant_data": ["x"],
         }
         if use_context:
             with model:

tests/sampling/test_forward.py

Lines changed: 49 additions & 0 deletions

@@ -1017,6 +1017,55 @@ def test_logging_sampled_basic_rvs_posterior_mutable(self, mock_sample_results,
         ]
         caplog.clear()
 
+    def test_observed_data_needed_in_pp(self):
+        # Model where y_data is not part of the generative graph.
+        # It shouldn't be needed to set a dummy value for posterior predictive sampling
+
+        with pm.Model(coords={"trial": range(5), "feature": range(3)}) as m:
+            x_data = pm.Data("x_data", np.random.normal(size=(5, 3)), dims=("trial", "feat"))
+            y_data = pm.Data("y_data", np.random.normal(size=(5,)), dims=("trial",))
+            sigma = pm.HalfNormal("sigma")
+            mu = x_data.sum(-1)
+            pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",))
+
+        prior = pm.sample_prior_predictive(samples=25).prior
+
+        fake_idata = InferenceData(posterior=prior)
+
+        new_coords = {"trial": range(2), "feature": range(3)}
+        new_x_data = np.random.normal(size=(2, 3))
+        with m:
+            pm.set_data(
+                {
+                    "x_data": new_x_data,
+                },
+                coords=new_coords,
+            )
+            pp = pm.sample_posterior_predictive(fake_idata, predictions=True, progressbar=False)
+        assert pp.predictions["y"].shape == (1, 25, 2)
+
+        # In this case y_data is part of the generative graph, so we must set it to a compatible value
+        with pm.Model(coords={"trial": range(5), "feature": range(3)}) as m:
+            x_data = pm.Data("x_data", np.random.normal(size=(5, 3)), dims=("trial", "feat"))
+            y_data = pm.Data("y_data", np.random.normal(size=(5,)), dims=("trial",))
+            sigma = pm.HalfNormal("sigma")
+            mu = (y_data.sum() * x_data).sum(-1)
+            pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",))
+
+        prior = pm.sample_prior_predictive(samples=25).prior
+
+        fake_idata = InferenceData(posterior=prior)
+
+        with m:
+            pm.set_data({"x_data": new_x_data}, coords=new_coords)
+            with pytest.raises(ValueError, match="conflicting sizes for dimension 'trial'"):
+                pm.sample_posterior_predictive(fake_idata, predictions=True, progressbar=False)
+
+        new_y_data = np.random.normal(size=(2,))
+        with m:
+            pm.set_data({"y_data": new_y_data})
+        assert pp.predictions["y"].shape == (1, 25, 2)
+
 
 @pytest.fixture(scope="class")
 def point_list_arg_bug_fixture() -> tuple[pm.Model, pm.backends.base.MultiTrace]:
