From 4a2bf0eab89f78d4f57c9f364ba955adf86a42ec Mon Sep 17 00:00:00 2001 From: danhphan Date: Thu, 26 May 2022 14:39:57 +1000 Subject: [PATCH 1/3] add constant_data into sample_numpyro_nuts() --- pymc/backends/arviz.py | 16 ++++++++++++++++ pymc/sampling_jax.py | 3 ++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 80da0a89d8..e0ca8fbba5 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -63,6 +63,22 @@ def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]: return observations +def find_constants(model: Optional["Model"]) -> Optional[Dict[str, Var]]: + """If there are constants available, return them as a dictionary.""" + if model is None or not model.named_vars: + return None + + constants = {} + for name, var in model.named_vars.items(): + if isinstance(var, (Constant, SharedVariable)): + if hasattr(var, "data"): + var = var.data + elif hasattr(var, "get_value"): + var = var.get_value() + constants[name] = var + return constants + + class _DefaultTrace: """ Utility for collecting samples into a dictionary. diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 644e9dd5ea..55f96c1c2a 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -30,7 +30,7 @@ from arviz.data.base import make_attrs from pymc import Model, modelcontext -from pymc.backends.arviz import find_observations +from pymc.backends.arviz import find_constants, find_observations from pymc.util import get_default_varnames warnings.warn("This module is experimental.") @@ -564,6 +564,7 @@ def sample_numpyro_nuts( posterior=posterior, log_likelihood=log_likelihood, observed_data=find_observations(model), + constant_data=find_constants(model), sample_stats=_sample_stats_to_xarray(pmap_numpyro), coords=coords, dims=dims, From f8d0023942c8839eae0dfd7e2e71eacd67b0b021 Mon Sep 17 00:00:00 2001 From: danhphan Date: Sun, 29 May 2022 13:46:27 +1000 Subject: [PATCH 2/3] refactor constant_data_to_xarray with find_constants --- pymc/backends/arviz.py | 76 +++++++++++++++++------------------------- pymc/sampling_jax.py | 1 + 2 files changed, 32 insertions(+), 45 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index e0ca8fbba5..1801bf431e 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -43,10 +43,10 @@ Var = Any # pylint: disable=invalid-name -def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]: +def find_observations(model: Optional["Model"]) -> Dict[str, Var]: """If there are observations available, return them as a dictionary.""" if model is None: - return None + return {} observations = {} for obs in model.observed_RVs: @@ -63,20 +63,37 @@ def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]: return observations -def find_constants(model: Optional["Model"]) -> Optional[Dict[str, Var]]: +def find_constants(model: Optional["Model"]) -> Dict[str, Var]: """If there are constants available, return them as a dictionary.""" - if model is None or not model.named_vars: - return None + # The constant data vars must be either pm.Data or TensorConstant or SharedVariable + if model is None: + return {} + + def is_data(name, var, model) -> bool: + observations = find_observations(model) + return ( + var not in model.deterministics + and var not in model.observed_RVs + and var not in model.free_RVs + and var not in model.potentials + and var not in model.value_vars + and name not in observations + and isinstance(var, (Constant, SharedVariable)) + ) - constants = {} + # The assumption is that constants (like pm.Data) are named + # variables that aren't observed or free RVs, nor are they + # deterministics, and then we eliminate observations. + constant_data = {} for name, var in model.named_vars.items(): - if isinstance(var, (Constant, SharedVariable)): - if hasattr(var, "data"): - var = var.data - elif hasattr(var, "get_value"): + if is_data(name, var, model): + if hasattr(var, "get_value"): var = var.get_value() - constants[name] = var - return constants + elif hasattr(var, "data"): + var = var.data + constant_data[name] = var + + return constant_data class _DefaultTrace: @@ -483,41 +500,10 @@ def observed_data_to_xarray(self): @requires("model") def constant_data_to_xarray(self): """Convert constant data to xarray.""" - # For constant data, we are concerned only with deterministics and - # data. The constant data vars must be either pm.Data - # (TensorConstant/SharedVariable) or pm.Deterministic - constant_data_vars = {} # type: Dict[str, Var] - - def is_data(name, var) -> bool: - assert self.model is not None - return ( - var not in self.model.deterministics - and var not in self.model.observed_RVs - and var not in self.model.free_RVs - and var not in self.model.potentials - and var not in self.model.value_vars - and (self.observations is None or name not in self.observations) - and isinstance(var, (Constant, SharedVariable)) - ) - - # I don't know how to find pm.Data, except that they are named - # variables that aren't observed or free RVs, nor are they - # deterministics, and then we eliminate observations. - for name, var in self.model.named_vars.items(): - if is_data(name, var): - constant_data_vars[name] = var - - if not constant_data_vars: + constant_data = find_constants(self.model) + if not constant_data: return None - constant_data = {} - for name, vals in constant_data_vars.items(): - if hasattr(vals, "get_value"): - vals = vals.get_value() - elif hasattr(vals, "data"): - vals = vals.data - constant_data[name] = vals - return dict_to_dataset( constant_data, library=pymc, diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 55f96c1c2a..b7834acc88 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -370,6 +370,7 @@ def sample_blackjax_nuts( posterior=posterior, log_likelihood=log_likelihood, observed_data=find_observations(model), + constant_data=find_constants(model), coords=coords, dims=dims, attrs=make_attrs(attrs, library=blackjax), From 6f50e3776e2286df6d7b124d7f439e2c741ba88f Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Mon, 20 Jun 2022 12:54:46 +0200 Subject: [PATCH 3/3] Test observations and constant data are included in InferenceData from JAX sampling --- pymc/backends/arviz.py | 10 ++-------- pymc/tests/test_sampling_jax.py | 5 ++++- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 1801bf431e..826f64d6f9 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -43,11 +43,8 @@ Var = Any # pylint: disable=invalid-name -def find_observations(model: Optional["Model"]) -> Dict[str, Var]: +def find_observations(model: "Model") -> Dict[str, Var]: """If there are observations available, return them as a dictionary.""" - if model is None: - return {} - observations = {} for obs in model.observed_RVs: aux_obs = getattr(obs.tag, "observations", None) @@ -63,12 +60,9 @@ def find_observations(model: Optional["Model"]) -> Dict[str, Var]: return observations -def find_constants(model: Optional["Model"]) -> Dict[str, Var]: +def find_constants(model: "Model") -> Dict[str, Var]: """If there are constants available, return them as a dictionary.""" # The constant data vars must be either pm.Data or TensorConstant or SharedVariable - if model is None: - return {} - def is_data(name, var, model) -> bool: observations = find_observations(model) return ( diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index e7d5d5e828..6624149f2e 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -171,8 +171,9 @@ def test_get_jaxified_logp(): def test_idata_kwargs(sampler, idata_kwargs, postprocessing_backend): with pm.Model() as m: x = pm.Normal("x") - z = pm.Normal("z") y = pm.Normal("y", x, observed=0) + pm.ConstantData("constantdata", [1, 2, 3]) + pm.MutableData("mutabledata", 2) idata = sampler( tune=50, draws=50, @@ -180,6 +181,8 @@ def test_idata_kwargs(sampler, idata_kwargs, postprocessing_backend): idata_kwargs=idata_kwargs, postprocessing_backend=postprocessing_backend, ) + assert "constantdata" in idata.constant_data + assert "mutabledata" in idata.constant_data if idata_kwargs.get("log_likelihood", True): assert "log_likelihood" in idata