Skip to content

Commit 403f2d5

Browse files
Add constant_data into sample_numpyro_nuts() (#5807)
Add constant_data into sample_numpyro_nuts() Co-authored-by: Michael Osthege <[email protected]>
1 parent 45a816e commit 403f2d5

File tree

3 files changed

+40
-39
lines changed

3 files changed

+40
-39
lines changed

pymc/backends/arviz.py

Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,8 @@
4343
Var = Any # pylint: disable=invalid-name
4444

4545

46-
def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
46+
def find_observations(model: "Model") -> Dict[str, Var]:
4747
"""If there are observations available, return them as a dictionary."""
48-
if model is None:
49-
return None
50-
5148
observations = {}
5249
for obs in model.observed_RVs:
5350
aux_obs = getattr(obs.tag, "observations", None)
@@ -63,6 +60,36 @@ def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
6360
return observations
6461

6562

63+
def find_constants(model: "Model") -> Dict[str, Var]:
64+
"""If there are constants available, return them as a dictionary."""
65+
# The constant data vars must be either pm.Data or TensorConstant or SharedVariable
66+
def is_data(name, var, model) -> bool:
67+
observations = find_observations(model)
68+
return (
69+
var not in model.deterministics
70+
and var not in model.observed_RVs
71+
and var not in model.free_RVs
72+
and var not in model.potentials
73+
and var not in model.value_vars
74+
and name not in observations
75+
and isinstance(var, (Constant, SharedVariable))
76+
)
77+
78+
# The assumption is that constants (like pm.Data) are named
79+
# variables that aren't observed or free RVs, nor are they
80+
# deterministics, and then we eliminate observations.
81+
constant_data = {}
82+
for name, var in model.named_vars.items():
83+
if is_data(name, var, model):
84+
if hasattr(var, "get_value"):
85+
var = var.get_value()
86+
elif hasattr(var, "data"):
87+
var = var.data
88+
constant_data[name] = var
89+
90+
return constant_data
91+
92+
6693
class _DefaultTrace:
6794
"""
6895
Utility for collecting samples into a dictionary.
@@ -467,41 +494,10 @@ def observed_data_to_xarray(self):
467494
@requires("model")
468495
def constant_data_to_xarray(self):
469496
"""Convert constant data to xarray."""
470-
# For constant data, we are concerned only with deterministics and
471-
# data. The constant data vars must be either pm.Data
472-
# (TensorConstant/SharedVariable) or pm.Deterministic
473-
constant_data_vars = {} # type: Dict[str, Var]
474-
475-
def is_data(name, var) -> bool:
476-
assert self.model is not None
477-
return (
478-
var not in self.model.deterministics
479-
and var not in self.model.observed_RVs
480-
and var not in self.model.free_RVs
481-
and var not in self.model.potentials
482-
and var not in self.model.value_vars
483-
and (self.observations is None or name not in self.observations)
484-
and isinstance(var, (Constant, SharedVariable))
485-
)
486-
487-
# I don't know how to find pm.Data, except that they are named
488-
# variables that aren't observed or free RVs, nor are they
489-
# deterministics, and then we eliminate observations.
490-
for name, var in self.model.named_vars.items():
491-
if is_data(name, var):
492-
constant_data_vars[name] = var
493-
494-
if not constant_data_vars:
497+
constant_data = find_constants(self.model)
498+
if not constant_data:
495499
return None
496500

497-
constant_data = {}
498-
for name, vals in constant_data_vars.items():
499-
if hasattr(vals, "get_value"):
500-
vals = vals.get_value()
501-
elif hasattr(vals, "data"):
502-
vals = vals.data
503-
constant_data[name] = vals
504-
505501
return dict_to_dataset(
506502
constant_data,
507503
library=pymc,

pymc/sampling_jax.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from arviz.data.base import make_attrs
3131

3232
from pymc import Model, modelcontext
33-
from pymc.backends.arviz import find_observations
33+
from pymc.backends.arviz import find_constants, find_observations
3434
from pymc.util import get_default_varnames
3535

3636
warnings.warn("This module is experimental.")
@@ -370,6 +370,7 @@ def sample_blackjax_nuts(
370370
posterior=posterior,
371371
log_likelihood=log_likelihood,
372372
observed_data=find_observations(model),
373+
constant_data=find_constants(model),
373374
coords=coords,
374375
dims=dims,
375376
attrs=make_attrs(attrs, library=blackjax),
@@ -564,6 +565,7 @@ def sample_numpyro_nuts(
564565
posterior=posterior,
565566
log_likelihood=log_likelihood,
566567
observed_data=find_observations(model),
568+
constant_data=find_constants(model),
567569
sample_stats=_sample_stats_to_xarray(pmap_numpyro),
568570
coords=coords,
569571
dims=dims,

pymc/tests/test_sampling_jax.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,15 +171,18 @@ def test_get_jaxified_logp():
171171
def test_idata_kwargs(sampler, idata_kwargs, postprocessing_backend):
172172
with pm.Model() as m:
173173
x = pm.Normal("x")
174-
z = pm.Normal("z")
175174
y = pm.Normal("y", x, observed=0)
175+
pm.ConstantData("constantdata", [1, 2, 3])
176+
pm.MutableData("mutabledata", 2)
176177
idata = sampler(
177178
tune=50,
178179
draws=50,
179180
chains=1,
180181
idata_kwargs=idata_kwargs,
181182
postprocessing_backend=postprocessing_backend,
182183
)
184+
assert "constantdata" in idata.constant_data
185+
assert "mutabledata" in idata.constant_data
183186

184187
if idata_kwargs.get("log_likelihood", True):
185188
assert "log_likelihood" in idata

0 commit comments

Comments
 (0)