diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index b274af10b6..b63f68acc6 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -67,6 +67,7 @@ import numpy as np +from pytensor.tensor.variable import TensorVariable from typing_extensions import TypeAlias from pymc.backends.arviz import predictions_to_inference_data, to_inference_data @@ -99,11 +100,12 @@ def _init_trace( stats_dtypes: list[dict[str, type]], trace: Optional[BaseTrace], model: Model, + trace_vars: Optional[list[TensorVariable]] = None, ) -> BaseTrace: """Initializes a trace backend for a chain.""" strace: BaseTrace if trace is None: - strace = NDArray(model=model) + strace = NDArray(model=model, vars=trace_vars) elif isinstance(trace, BaseTrace): if len(trace) > 0: raise ValueError("Continuation of traces is no longer supported.") @@ -123,6 +125,7 @@ def init_traces( step: Union[BlockedStep, CompoundStep], initial_point: Mapping[str, np.ndarray], model: Model, + trace_vars: Optional[list[TensorVariable]] = None, ) -> tuple[Optional[RunType], Sequence[IBaseTrace]]: """Initializes a trace recorder for each chain.""" if HAS_MCB and isinstance(backend, Backend): @@ -142,6 +145,7 @@ def init_traces( chain_number=chain_number, trace=backend, model=model, + trace_vars=trace_vars, ) for chain_number in range(chains) ] diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 01f8b0d502..f048cc2938 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -532,15 +532,19 @@ def sample_jax_nuts( model = modelcontext(model) - if var_names is None: - var_names = model.unobserved_value_vars + if var_names is not None: + filtered_var_names = [v for v in model.unobserved_value_vars if v.name in var_names] + else: + filtered_var_names = model.unobserved_value_vars if nuts_kwargs is None: nuts_kwargs = {} else: nuts_kwargs = nuts_kwargs.copy() - vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) + vars_to_sample = list( + get_default_varnames(filtered_var_names, include_transformed=keep_untransformed) + ) (random_seed,) = _get_seeds_per_chain(random_seed, 1) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 313a6ab8c0..dd97f78c88 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -264,6 +264,7 @@ def _sample_external_nuts( random_seed: Union[RandomState, None], initvals: Union[StartDict, Sequence[Optional[StartDict]], None], model: Model, + var_names: Optional[Sequence[str]], progressbar: bool, idata_kwargs: Optional[dict], nuts_sampler_kwargs: Optional[dict], @@ -292,6 +293,11 @@ def _sample_external_nuts( "`idata_kwargs` are currently ignored by the nutpie sampler", UserWarning, ) + if var_names is not None: + warnings.warn( + "`var_names` are currently ignored by the nutpie sampler", + UserWarning, + ) compiled_model = nutpie.compile_pymc_model(model) t_start = time.time() idata = nutpie.sample( @@ -348,6 +354,7 @@ def _sample_external_nuts( random_seed=random_seed, initvals=initvals, model=model, + var_names=var_names, progressbar=progressbar, nuts_sampler=sampler, idata_kwargs=idata_kwargs, @@ -371,6 +378,7 @@ def sample( random_seed: RandomState = None, progressbar: bool = True, step=None, + var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, init: str = "auto", @@ -399,6 +407,7 @@ def sample( random_seed: RandomState = None, progressbar: bool = True, step=None, + var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, init: str = "auto", @@ -427,6 +436,7 @@ def sample( random_seed: RandomState = None, progressbar: bool = True, step=None, + var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, init: str = "auto", @@ -478,6 +488,8 @@ def sample( A step function or collection of functions. If there are variables without step methods, step methods for those variables will be assigned automatically. By default the NUTS step method will be used, if appropriate to the model. + var_names : list of str, optional + Names of variables to be stored in the trace. Defaults to all free variables and deterministics. nuts_sampler : str Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"]. This requires the chosen sampler to be installed. @@ -680,6 +692,7 @@ def sample( random_seed=random_seed, initvals=initvals, model=model, + var_names=var_names, progressbar=progressbar, idata_kwargs=idata_kwargs, nuts_sampler_kwargs=nuts_sampler_kwargs, @@ -722,12 +735,19 @@ def sample( model.check_start_vals(ip) _check_start_shape(model, ip) + if var_names is not None: + trace_vars = [v for v in model.unobserved_RVs if v.name in var_names] + assert len(trace_vars) == len(var_names), "Not all var_names were found in the model" + else: + trace_vars = None + # Create trace backends for each chain run, traces = init_traces( backend=trace, chains=chains, expected_length=draws + tune, step=step, + trace_vars=trace_vars, initial_point=ip, model=model, ) @@ -739,6 +759,7 @@ def sample( "traces": traces, "chains": chains, "tune": tune, + "var_names": var_names, "progressbar": progressbar, "model": model, "cores": cores, diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index d8d0cae246..77121b37ce 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -491,6 +491,15 @@ def test_sample_partially_observed(): assert idata.posterior["x"].shape == (1, 10, 3) +def test_sample_var_names(): + with pm.Model() as model: + a = pm.Normal("a") + b = pm.Deterministic("b", a**2) + idata = pm.sample(10, tune=10, nuts_sampler="numpyro", var_names=["a"]) + assert "a" in idata.posterior + assert "b" not in idata.posterior + + @pytest.mark.parametrize("nuts_sampler", ("numpyro", "blackjax")) def test_convergence_warnings(caplog, nuts_sampler): with pm.Model() as m: diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 0fc03dd631..a18430818d 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -694,6 +694,15 @@ def test_no_init_nuts_compound(caplog): assert "Initializing NUTS" not in caplog.text +def test_sample_var_names(): + with pm.Model() as model: + a = pm.Normal("a") + b = pm.Deterministic("b", a**2) + idata = pm.sample(10, tune=10, var_names=["a"]) + assert "a" in idata.posterior + assert "b" not in idata.posterior + + class TestAssignStepMethods: def test_bernoulli(self): """Test bernoulli distribution is assigned binary gibbs metropolis method"""