Skip to content

Add public testing function to mock sample #7761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 1, 2025
Merged
52 changes: 52 additions & 0 deletions pymc/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from collections.abc import Callable, Sequence
from typing import Any

from arviz import InferenceData
from xarray import DataArray
import numpy as np
import pytensor
import pytensor.tensor as pt
Expand Down Expand Up @@ -982,3 +984,53 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None:
rvs = rvs_in_graph(vars)
if rvs:
raise AssertionError(f"RV found in graph: {rvs}")


def mock_sample(*args, **kwargs):
"""Mock the pm.sample function by returning prior predictive samples as posterior.

Useful for testing models that use pm.sample without running MCMC sampling.

Examples
--------
Using mock_sample with pytest

.. code-block:: python

import pytest

import pymc as pm
from pymc.testing import mock_sample


@pytest.fixture(scope="module")
def mock_pymc_sample():
original_sample = pm.sample
pm.sample = mock_sample

yield

pm.sample = original_sample

"""
random_seed = kwargs.get("random_seed", None)
model = kwargs.get("model", None)
draws = kwargs.get("draws", 10)
n_chains = kwargs.get("chains", 1)
idata: InferenceData = pm.sample_prior_predictive(
model=model,
random_seed=random_seed,
draws=draws,
)

expanded_chains = DataArray(
np.ones(n_chains),
coords={"chain": np.arange(n_chains)},
)
idata.add_groups(
posterior=(idata.prior.mean("chain") * expanded_chains).transpose("chain", "draw", ...)
)
del idata.prior
if "prior_predictive" in idata:
del idata.prior_predictive
return idata
Loading