Skip to content

Commit 7bf58d5

Browse files
authored
Add public testing function to mock sample (#7761)
* push up pymc-marketing mock * run pre-commit * add small test * use positional arg for draws like in actual sample * better for mypy * provide the setup and breakdown for pytest fixtures * change name for testing convention * bit more explicit on the test * add to the documentation * use expand_dims method * add to the toc * alterations to docstrings * change format and provide links * link to the functions in the docstring
1 parent 5003e50 commit 7bf58d5

File tree

4 files changed

+178
-1
lines changed

4 files changed

+178
-1
lines changed

docs/source/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ API
2222
api/shape_utils
2323
api/backends
2424
api/misc
25+
api/testing
2526

2627
------------------
2728
Dimensionality

docs/source/api/testing.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
=======
2+
Testing
3+
=======
4+
5+
This submodule contains tools to help with testing PyMC code.
6+
7+
8+
.. currentmodule:: pymc.testing
9+
10+
.. autosummary::
11+
:toctree: generated/
12+
13+
mock_sample
14+
mock_sample_setup_and_teardown

pymc/testing.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pytensor
2323
import pytensor.tensor as pt
2424

25+
from arviz import InferenceData
2526
from numpy import random as nr
2627
from numpy import testing as npt
2728
from pytensor.compile.mode import Mode
@@ -982,3 +983,115 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None:
982983
rvs = rvs_in_graph(vars)
983984
if rvs:
984985
raise AssertionError(f"RV found in graph: {rvs}")
986+
987+
988+
def mock_sample(draws: int = 10, **kwargs):
989+
"""Mock :func:`pymc.sample` with :func:`pymc.sample_prior_predictive`.
990+
991+
Useful for testing models that use pm.sample without running MCMC sampling.
992+
993+
Examples
994+
--------
995+
Using mock_sample with pytest
996+
997+
.. note::
998+
999+
Use :func:`pymc.testing.mock_sample_setup_and_teardown` directly for pytest fixtures.
1000+
1001+
.. code-block:: python
1002+
1003+
import pytest
1004+
1005+
import pymc as pm
1006+
from pymc.testing import mock_sample
1007+
1008+
1009+
@pytest.fixture(scope="module")
1010+
def mock_pymc_sample():
1011+
original_sample = pm.sample
1012+
pm.sample = mock_sample
1013+
1014+
yield
1015+
1016+
pm.sample = original_sample
1017+
1018+
"""
1019+
random_seed = kwargs.get("random_seed", None)
1020+
model = kwargs.get("model", None)
1021+
draws = kwargs.get("draws", draws)
1022+
n_chains = kwargs.get("chains", 1)
1023+
idata: InferenceData = pm.sample_prior_predictive(
1024+
model=model,
1025+
random_seed=random_seed,
1026+
draws=draws,
1027+
)
1028+
1029+
idata.add_groups(
1030+
posterior=(
1031+
idata["prior"]
1032+
.isel(chain=0)
1033+
.expand_dims({"chain": range(n_chains)})
1034+
.transpose("chain", "draw", ...)
1035+
)
1036+
)
1037+
del idata["prior"]
1038+
if "prior_predictive" in idata:
1039+
del idata["prior_predictive"]
1040+
return idata
1041+
1042+
1043+
def mock_sample_setup_and_teardown():
1044+
"""Set up and tear down mocking of PyMC sampling functions for testing.
1045+
1046+
This function is designed to be used with pytest fixtures to temporarily replace
1047+
PyMC's sampling functionality with faster alternatives for testing purposes.
1048+
1049+
Effects during the fixture's active period:
1050+
1051+
* Replaces :func:`pymc.sample` with :func:`pymc.testing.mock_sample`, which uses
1052+
prior predictive sampling instead of MCMC
1053+
* Replaces distributions:
1054+
* :class:`pymc.Flat` with :class:`pymc.Normal`
1055+
* :class:`pymc.HalfFlat` with :class:`pymc.HalfNormal`
1056+
* Automatically restores all original functions and distributions after the test completes
1057+
1058+
Examples
1059+
--------
1060+
Use with `pytest` to mock actual PyMC sampling in test suite.
1061+
1062+
.. code-block:: python
1063+
1064+
# tests/conftest.py
1065+
import pytest
1066+
import pymc as pm
1067+
from pymc.testing import mock_sample_setup_and_teardown
1068+
1069+
# Register as a pytest fixture
1070+
mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)
1071+
1072+
1073+
# tests/test_model.py
1074+
# Use in a test function
1075+
def test_model_inference(mock_pymc_sample):
1076+
with pm.Model() as model:
1077+
x = pm.Normal("x", 0, 1)
1078+
# This will use mock_sample instead of actual MCMC
1079+
idata = pm.sample()
1080+
# Test with the inference data...
1081+
1082+
"""
1083+
import pymc as pm
1084+
1085+
original_flat = pm.Flat
1086+
original_half_flat = pm.HalfFlat
1087+
original_sample = pm.sample
1088+
1089+
pm.sample = mock_sample
1090+
pm.Flat = pm.Normal
1091+
pm.HalfFlat = pm.HalfNormal
1092+
1093+
yield
1094+
1095+
pm.sample = original_sample
1096+
pm.Flat = original_flat
1097+
pm.HalfFlat = original_half_flat

tests/test_testing.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515

1616
import pytest
1717

18-
from pymc.testing import Domain
18+
import pymc as pm
19+
20+
from pymc.testing import Domain, mock_sample, mock_sample_setup_and_teardown
21+
from tests.models import simple_normal
1922

2023

2124
@pytest.mark.parametrize(
@@ -32,3 +35,49 @@
3235
def test_domain(values, edges, expectation):
3336
with expectation:
3437
Domain(values, edges=edges)
38+
39+
40+
@pytest.mark.parametrize(
41+
"args, kwargs, expected_size",
42+
[
43+
pytest.param((), {}, (1, 10), id="default"),
44+
pytest.param((100,), {}, (1, 100), id="positional-draws"),
45+
pytest.param((), {"draws": 100}, (1, 100), id="keyword-draws"),
46+
pytest.param((100,), {"chains": 6}, (6, 100), id="chains"),
47+
],
48+
)
49+
def test_mock_sample(args, kwargs, expected_size) -> None:
50+
expected_chains, expected_draws = expected_size
51+
_, model, _ = simple_normal(bounded_prior=True)
52+
53+
with model:
54+
idata = mock_sample(*args, **kwargs)
55+
56+
assert "posterior" in idata
57+
assert "observed_data" in idata
58+
assert "prior" not in idata
59+
assert "posterior_predictive" not in idata
60+
assert "sample_stats" not in idata
61+
62+
assert idata.posterior.sizes == {"chain": expected_chains, "draw": expected_draws}
63+
64+
65+
mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)
66+
67+
68+
@pytest.fixture(scope="function")
69+
def dummy_model() -> pm.Model:
70+
with pm.Model() as model:
71+
pm.Flat("flat")
72+
pm.HalfFlat("half_flat")
73+
74+
return model
75+
76+
77+
def test_fixture(mock_pymc_sample, dummy_model) -> None:
78+
with dummy_model:
79+
idata = pm.sample()
80+
81+
posterior = idata.posterior
82+
assert posterior.sizes == {"chain": 1, "draw": 10}
83+
assert (posterior["half_flat"] >= 0).all()

0 commit comments

Comments
 (0)