diff --git a/pymc/step_methods/external/__init__.py b/pymc/step_methods/external/__init__.py
new file mode 100644
index 0000000000..f004924b2c
--- /dev/null
+++ b/pymc/step_methods/external/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2024 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""External samplers integration for PyMC."""
+
+from pymc.step_methods.external.base import ExternalSampler
+from pymc.step_methods.external.nutpie import NutPie
+
+__all__ = ["ExternalSampler", "NutPie"]
diff --git a/pymc/step_methods/external/base.py b/pymc/step_methods/external/base.py
new file mode 100644
index 0000000000..b61e037978
--- /dev/null
+++ b/pymc/step_methods/external/base.py
@@ -0,0 +1,124 @@
+# Copyright 2024 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+
+from arviz import InferenceData
+
+from pymc.step_methods.compound import BlockedStep, Competence
+
+
+class ExternalSampler(BlockedStep, ABC):
+    """Base class for external samplers.
+
+    External samplers manage their own MCMC loop rather than using PyMC's.
+    These samplers (like NutPie, BlackJax, etc.) are designed to run
+    their own efficient loop inside their implementation.
+
+    Attributes
+    ----------
+    is_external : bool
+        Flag indicating that this is an external sampler that needs
+        special handling in PyMC's sampling loops.
+    """
+
+    is_external = True
+
+    def __init__(
+        self,
+        vars=None,
+        model=None,
+        **kwargs,
+    ):
+        """Initialize external sampler.
+
+        Parameters
+        ----------
+        vars : list, optional
+            Variables to be sampled
+        model : Model, optional
+            PyMC model
+        **kwargs
+            Sampler-specific arguments
+        """
+        self.model = model
+        self._vars = vars
+        self._kwargs = kwargs
+
+    @abstractmethod
+    def sample(
+        self,
+        draws: int,
+        tune: int = 1000,
+        chains: int = 4,
+        random_seed=None,
+        initvals=None,
+        progressbar=True,
+        cores=None,
+        **kwargs,
+    ) -> InferenceData:
+        """Run external sampler and return results as InferenceData.
+
+        Parameters
+        ----------
+        draws : int
+            Number of draws per chain
+        tune : int
+            Number of tuning draws per chain
+        chains : int
+            Number of chains to sample
+        random_seed : int or sequence, optional
+            Random seed(s) for reproducibility
+        initvals : dict or list of dict, optional
+            Initial values for variables
+        progressbar : bool
+            Whether to display progress bar
+        cores : int, optional
+            Number of CPU cores to use
+        **kwargs
+            Additional sampler-specific parameters
+
+        Returns
+        -------
+        InferenceData
+            ArviZ InferenceData object with sampling results
+        """
+        pass
+
+    def step(self, point):
+        """Do not use this method. External samplers use their own sampling loop.
+
+        External samplers do not use PyMC's step() mechanism.
+        """
+        raise NotImplementedError(
+            "External samplers use their own sampling loop rather than PyMC's step() method."
+        )
+
+    @staticmethod
+    def competence(var, has_grad):
+        """Determine competence level for sampling var.
+
+        Parameters
+        ----------
+        var : Variable
+            Variable to be sampled
+        has_grad : bool
+            Whether gradient information is available
+
+        Returns
+        -------
+        Competence
+            Enum indicating competence level for this variable
+        """
+        return Competence.COMPATIBLE
diff --git a/pymc/step_methods/external/nutpie.py b/pymc/step_methods/external/nutpie.py
new file mode 100644
index 0000000000..c99bcf2100
--- /dev/null
+++ b/pymc/step_methods/external/nutpie.py
@@ -0,0 +1,223 @@
+# Copyright 2024 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import warnings
+
+from typing import Literal
+
+from arviz import InferenceData
+
+from pymc.model import Model
+from pymc.step_methods.compound import Competence
+from pymc.step_methods.external.base import ExternalSampler
+from pymc.vartypes import continuous_types
+
+logger = logging.getLogger("pymc")
+
+try:
+    import nutpie
+
+    # Check if it's actually installed and not just an empty mock module
+    NUTPIE_AVAILABLE = hasattr(nutpie, "compile_pymc_model")
+except ImportError:
+    NUTPIE_AVAILABLE = False
+
+
+class NutPie(ExternalSampler):
+    """NutPie No-U-Turn Sampler.
+
+    This class provides an interface to the NutPie sampler, which is a high-performance
+    implementation of the No-U-Turn Sampler (NUTS). Unlike PyMC's native NUTS implementation,
+    sampling is delegated entirely to ``nutpie.sample()``, which runs its own sampling loop
+    on a model compiled with ``nutpie.compile_pymc_model()``.
+
+    Parameters
+    ----------
+    vars : list, optional
+        Variables to be sampled
+    model : Model, optional
+        PyMC model
+    backend : {"numba", "jax"}, default="numba"
+        Which backend to use for computation
+    target_accept : float, default=0.8
+        Target acceptance rate for step size adaptation
+    max_treedepth : int, default=10
+        Maximum tree depth for NUTS (passed as 'maxdepth' to NutPie)
+    **kwargs
+        Additional parameters passed to nutpie.sample()
+
+    Notes
+    -----
+    Requires the nutpie package to be installed:
+    pip install nutpie
+    """
+
+    name = "nutpie"
+
+    def __init__(
+        self,
+        vars=None,
+        *,
+        model=None,
+        backend: Literal["numba", "jax"] = "numba",
+        target_accept: float = 0.8,
+        max_treedepth: int = 10,
+        **kwargs,
+    ):
+        """Initialize NutPie sampler."""
+        if not NUTPIE_AVAILABLE:
+            raise ImportError("nutpie not found. Install it with: pip install nutpie")
+
+        super().__init__(vars=vars, model=model)
+
+        self.backend = backend
+        self.target_accept = target_accept
+        self.max_treedepth = max_treedepth
+        self.nutpie_kwargs = kwargs
+
+    def sample(
+        self,
+        draws: int,
+        tune: int = 1000,
+        chains: int = 4,
+        random_seed=None,
+        initvals=None,
+        progressbar=True,
+        cores=None,
+        idata_kwargs=None,
+        compute_convergence_checks=True,
+        **kwargs,
+    ) -> InferenceData:
+        """Run NutPie sampler and return results as InferenceData.
+
+        Parameters
+        ----------
+        draws : int
+            Number of draws per chain
+        tune : int
+            Number of tuning draws per chain
+        chains : int
+            Number of chains to sample
+        random_seed : int or sequence, optional
+            Forwarded unchanged to nutpie.sample() as ``seed``
+        initvals : dict or list of dict, optional
+            Initial values for variables (currently not used by NutPie)
+        progressbar : bool
+            Whether to display progress bar
+        cores : int, optional
+            Accepted for interface compatibility; currently not forwarded to nutpie
+        idata_kwargs : dict, optional
+            Accepted for interface compatibility; currently unused because
+            nutpie.sample() already returns an InferenceData object
+        compute_convergence_checks : bool
+            Accepted for interface compatibility; convergence checks are
+            currently not computed
+        **kwargs
+            Additional parameters forwarded to nutpie.sample(). ``model`` and
+            ``vars`` are consumed here and not forwarded.
+
+        Returns
+        -------
+        InferenceData
+            ArviZ InferenceData object with sampling results
+        """
+        model = kwargs.pop("model", self.model)
+        if model is None:
+            model = Model.get_context()
+
+        # The variable selection is not used: the whole model is compiled below.
+        # Drop the key so it is not forwarded to nutpie.sample().
+        kwargs.pop("vars", None)
+
+        # Create a NutPie model
+        logger.info("Compiling NutPie model")
+        nutpie_model = nutpie.compile_pymc_model(
+            model,
+            backend=self.backend,
+        )
+
+        if initvals is not None:
+            warnings.warn(
+                "`initvals` are currently not passed to nutpie sampler. "
+                "Use `init_mean` kwarg following nutpie specification instead.",
+                UserWarning,
+                stacklevel=2,
+            )
+
+        # Later entries win: sampler settings from __init__, then constructor
+        # kwargs, then call-time kwargs.
+        nutpie_kwargs = {
+            "target_accept": self.target_accept,
+            "maxdepth": self.max_treedepth,
+            **self.nutpie_kwargs,
+            **kwargs,
+        }
+
+        # Set up random seed
+        if random_seed is not None:
+            nutpie_kwargs["seed"] = random_seed
+
+        # Update parameter names to match NutPie's API
+        if "progressbar" in nutpie_kwargs:
+            nutpie_kwargs["progress_bar"] = nutpie_kwargs.pop("progressbar")
+
+        # Pass progressbar from the sample function arguments
+        if progressbar is not None:
+            nutpie_kwargs["progress_bar"] = progressbar
+
+        # Run the sampler
+        logger.info(
+            "Running NutPie sampler with %s chains, %s tuning steps, and %s draws",
+            chains,
+            tune,
+            draws,
+        )
+
+        # NutPie already returns an InferenceData object
+        idata = nutpie.sample(
+            nutpie_model,
+            draws=draws,
+            tune=tune,
+            chains=chains,
+            **nutpie_kwargs,
+        )
+
+        # Set tuning steps attribute if possible
+        try:
+            idata.posterior.attrs["tuning_steps"] = tune
+        except (AttributeError, KeyError):
+            logger.warning("Could not set tuning_steps attribute on InferenceData")
+
+        return idata
+
+    @staticmethod
+    def competence(var, has_grad):
+        """Determine competence level for sampling var.
+
+        Parameters
+        ----------
+        var : Variable
+            Variable to be sampled
+        has_grad : bool
+            Whether gradient information is available
+
+        Returns
+        -------
+        Competence
+            Enum indicating competence level for this variable
+        """
+        if var.dtype in continuous_types and has_grad:
+            return Competence.IDEAL
+        return Competence.INCOMPATIBLE
diff --git a/tests/step_methods/test_external.py b/tests/step_methods/test_external.py
new file mode 100644
index 0000000000..5366676425
--- /dev/null
+++ b/tests/step_methods/test_external.py
@@ -0,0 +1,112 @@
+# Copyright 2024 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from pymc import Model, Normal, sample
+from pymc.step_methods.external import NutPie
+from pymc.step_methods.external.nutpie import NUTPIE_AVAILABLE
+
+
+@pytest.mark.skipif(not NUTPIE_AVAILABLE, reason="NutPie not installed")
+def test_nutpie_integration():
+    """Test basic usage of NutPie as a PyMC step method."""
+    with Model() as model:
+        x = Normal("x", mu=0, sigma=1)
+
+        # Create NutPie sampler with numba backend
+        nutpie_sampler = NutPie(backend="numba")
+
+        assert nutpie_sampler.is_external  # inherited from ExternalSampler
+
+        # Sample using external sampler
+        trace = sample(
+            draws=10,  # Use fewer draws for faster testing
+            tune=10,
+            step=nutpie_sampler,
+            chains=1,  # Use just one chain for simplicity
+            random_seed=42,
+            progressbar=False,
+        )
+
+        # Check that the sampling worked
+        assert "x" in trace.posterior
+        assert trace.posterior.x.shape == (1, 10)
+
+        # Check that the sampler stats were recorded
+        expected_stats = ["diverging", "energy"]
+        for stat in expected_stats:
+            assert stat in trace.sample_stats
+
+
+@pytest.mark.skipif(not NUTPIE_AVAILABLE, reason="NutPie not installed")
+def test_nutpie_jax_backend():
+    """Test NutPie with JAX backend."""
+    try:
+        import importlib.util
+
+        jax_available = importlib.util.find_spec("jax") is not None
+    except ImportError:
+        jax_available = False
+
+    if not jax_available:
+        pytest.skip("JAX not installed")
+
+    with Model() as model:
+        x = Normal("x", mu=0, sigma=1)
+
+        # Create NutPie sampler with JAX backend
+        nutpie_sampler = NutPie(backend="jax")
+
+        # Sample using external sampler
+        trace = sample(
+            draws=10,
+            tune=10,
+            step=nutpie_sampler,
+            chains=1,
+            random_seed=42,
+            progressbar=False,
+        )
+
+        # Check that the sampling worked
+        assert "x" in trace.posterior
+        assert trace.posterior.x.shape == (1, 10)
+
+
+@pytest.mark.skipif(not NUTPIE_AVAILABLE, reason="NutPie not installed")
+def test_nutpie_custom_params():
+    """Test NutPie with custom parameters."""
+    with Model() as model:
+        x = Normal("x", mu=0, sigma=1)
+
+        # Create NutPie sampler with custom parameters
+        nutpie_sampler = NutPie(
+            backend="numba",
+            target_accept=0.9,
+            max_treedepth=8,
+        )
+
+        # Sample using external sampler
+        trace = sample(
+            draws=10,
+            tune=10,
+            step=nutpie_sampler,
+            chains=1,
+            random_seed=42,
+            progressbar=False,
+        )
+
+        # Check that the sampling worked
+        assert "x" in trace.posterior
+        assert trace.posterior.x.shape == (1, 10)