diff --git a/pymc3/sampling.py b/pymc3/sampling.py index a9771d3e55..a97e8dd707 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -414,15 +414,15 @@ def sample( """ model = modelcontext(model) if start is None: - start = model.test_point + check_start_vals(model.test_point, model) else: if isinstance(start, dict): update_start_vals(start, model.test_point, model) else: for chain_start_vals in start: update_start_vals(chain_start_vals, model.test_point, model) + check_start_vals(start, model) - check_start_vals(start, model) if cores is None: cores = min(4, _cpu_count()) @@ -490,9 +490,9 @@ def sample( progressbar=progressbar, **kwargs, ) - check_start_vals(start_, model) if start is None: start = start_ + check_start_vals(start, model) except (AttributeError, NotImplementedError, tg.NullTypeGradError): # gradient computation failed _log.info("Initializing NUTS failed. " "Falling back to elementwise auto-assignment.") diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 2185542f17..2ccaf80623 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import ExitStack as does_not_raise from itertools import combinations from typing import Tuple import numpy as np @@ -25,7 +26,7 @@ import theano from pymc3.tests.models import simple_init from pymc3.tests.helpers import SeededTest -from pymc3.exceptions import IncorrectArgumentsError +from pymc3.exceptions import IncorrectArgumentsError, SamplingError from scipy import stats import pytest @@ -761,6 +762,35 @@ def test_exec_nuts_init(method): assert "a" in start[0] and "b_log__" in start[0] +@pytest.mark.parametrize( + "init, start, expectation", + [ + ("auto", None, pytest.raises(SamplingError)), + ("jitter+adapt_diag", None, pytest.raises(SamplingError)), + ("auto", {"x": 0}, does_not_raise()), + ("jitter+adapt_diag", {"x": 0}, does_not_raise()), + ("adapt_diag", None, does_not_raise()), + ], +) +def test_default_sample_nuts_jitter(init, start, expectation, monkeypatch): + # This test tries to check whether the starting points returned by init_nuts are actually + # being used when pm.sample() is called without specifying an explicit start point (see + # https://github.com/pymc-devs/pymc3/pull/4285). + def _mocked_init_nuts(*args, **kwargs): + if init == "adapt_diag": + start_ = [{"x": np.array(0.79788456)}] + else: + start_ = [{"x": np.array(-0.04949886)}] + _, step = pm.init_nuts(*args, **kwargs) + return start_, step + + monkeypatch.setattr("pymc3.sampling.init_nuts", _mocked_init_nuts) + with pm.Model() as m: + x = pm.HalfNormal("x", transform=None) + with expectation: + pm.sample(tune=1, draws=0, chains=1, init=init, start=start) + + @pytest.fixture(scope="class") def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]: with pm.Model() as pmodel: