Skip to content

Commit 6f50e37

Browse files
Test observations and constant data are included in InferenceData from JAX sampling
1 parent f8d0023 commit 6f50e37

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

pymc/backends/arviz.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,8 @@
4343
Var = Any # pylint: disable=invalid-name
4444

4545

46-
def find_observations(model: Optional["Model"]) -> Dict[str, Var]:
46+
def find_observations(model: "Model") -> Dict[str, Var]:
4747
"""If there are observations available, return them as a dictionary."""
48-
if model is None:
49-
return {}
50-
5148
observations = {}
5249
for obs in model.observed_RVs:
5350
aux_obs = getattr(obs.tag, "observations", None)
@@ -63,12 +60,9 @@ def find_observations(model: Optional["Model"]) -> Dict[str, Var]:
6360
return observations
6461

6562

66-
def find_constants(model: Optional["Model"]) -> Dict[str, Var]:
63+
def find_constants(model: "Model") -> Dict[str, Var]:
6764
"""If there are constants available, return them as a dictionary."""
6865
# The constant data vars must be either pm.Data or TensorConstant or SharedVariable
69-
if model is None:
70-
return {}
71-
7266
def is_data(name, var, model) -> bool:
7367
observations = find_observations(model)
7468
return (

pymc/tests/test_sampling_jax.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,15 +171,18 @@ def test_get_jaxified_logp():
171171
def test_idata_kwargs(sampler, idata_kwargs, postprocessing_backend):
172172
with pm.Model() as m:
173173
x = pm.Normal("x")
174-
z = pm.Normal("z")
175174
y = pm.Normal("y", x, observed=0)
175+
pm.ConstantData("constantdata", [1, 2, 3])
176+
pm.MutableData("mutabledata", 2)
176177
idata = sampler(
177178
tune=50,
178179
draws=50,
179180
chains=1,
180181
idata_kwargs=idata_kwargs,
181182
postprocessing_backend=postprocessing_backend,
182183
)
184+
assert "constantdata" in idata.constant_data
185+
assert "mutabledata" in idata.constant_data
183186

184187
if idata_kwargs.get("log_likelihood", True):
185188
assert "log_likelihood" in idata

0 commit comments

Comments
 (0)