Skip to content

Commit 4a2bf0e

Browse files
danhphanmichaelosthege
authored andcommitted
add constant_data into sample_numpyro_nuts()
1 parent 7cc24bc commit 4a2bf0e

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

pymc/backends/arviz.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,22 @@ def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
6363
return observations
6464

6565

66+
def find_constants(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
67+
"""If there are constants available, return them as a dictionary."""
68+
if model is None or not model.named_vars:
69+
return None
70+
71+
constants = {}
72+
for name, var in model.named_vars.items():
73+
if isinstance(var, (Constant, SharedVariable)):
74+
if hasattr(var, "data"):
75+
var = var.data
76+
elif hasattr(var, "get_value"):
77+
var = var.get_value()
78+
constants[name] = var
79+
return constants
80+
81+
6682
class _DefaultTrace:
6783
"""
6884
Utility for collecting samples into a dictionary.

pymc/sampling_jax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from arviz.data.base import make_attrs
3131

3232
from pymc import Model, modelcontext
33-
from pymc.backends.arviz import find_observations
33+
from pymc.backends.arviz import find_constants, find_observations
3434
from pymc.util import get_default_varnames
3535

3636
warnings.warn("This module is experimental.")
@@ -564,6 +564,7 @@ def sample_numpyro_nuts(
564564
posterior=posterior,
565565
log_likelihood=log_likelihood,
566566
observed_data=find_observations(model),
567+
constant_data=find_constants(model),
567568
sample_stats=_sample_stats_to_xarray(pmap_numpyro),
568569
coords=coords,
569570
dims=dims,

0 commit comments

Comments
 (0)