diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 83b56aa4ed..f6ebbe8513 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -3,6 +3,7 @@ ## PyMC3 3.10.1 (on deck) ### Maintenance +- Fixed bug whereby partial traces returned after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318) - Make `sample_shape` same across all contexts in `draw_values` (see [#4305](https://github.com/pymc-devs/pymc3/pull/4305)). ## PyMC3 3.10.0 (7 December 2020) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 4ae11ccf8f..47cbcac381 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -1508,6 +1508,14 @@ def _mp_sample( def _choose_chains(traces, tune): + """ + Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized. + + We get here after a ``KeyboardInterrupt``, and so the different + traces have different lengths. We therefore pick the number of + traces such that (number of traces) * (length of shortest trace) + is maximised. 
+ """ if tune is None: tune = 0 @@ -1518,22 +1526,13 @@ def _choose_chains(traces, tune): if not sum(lengths): raise ValueError("Not enough samples to build a trace.") - idxs = np.argsort(lengths)[::-1] + idxs = np.argsort(lengths) l_sort = np.array(lengths)[idxs] - final_length = l_sort[0] - last_total = 0 - for i, length in enumerate(l_sort): - total = (i + 1) * length - if total < last_total: - use_until = i - break - last_total = total - final_length = length - else: - use_until = len(lengths) + use_until = np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1)[::-1]) + final_length = l_sort[use_until] - return [traces[idx] for idx in idxs[:use_until]], final_length + tune + return [traces[idx] for idx in idxs[use_until:]], final_length + tune def stop_tuning(step): diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index a167b4f759..cb5f9806ac 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -30,6 +30,7 @@ import pymc3 as pm +from pymc3.backends.ndarray import NDArray from pymc3.exceptions import IncorrectArgumentsError, SamplingError from pymc3.tests.helpers import SeededTest from pymc3.tests.models import simple_init @@ -299,6 +300,33 @@ def test_partial_trace_sample(): trace = pm.sample(trace=[a]) +@pytest.mark.parametrize( + "n_points, tune, expected_length, expected_n_traces", + [ + ((5, 2, 2), 0, 2, 3), + ((6, 1, 1), 1, 6, 1), + ], +) +def test_choose_chains(n_points, tune, expected_length, expected_n_traces): + with pm.Model() as model: + a = pm.Normal("a", mu=0, sigma=1) + trace_0 = NDArray(model) + trace_1 = NDArray(model) + trace_2 = NDArray(model) + trace_0.setup(n_points[0], 1) + trace_1.setup(n_points[1], 1) + trace_2.setup(n_points[2], 1) + for _ in range(n_points[0]): + trace_0.record({"a": 0}) + for _ in range(n_points[1]): + trace_1.record({"a": 0}) + for _ in range(n_points[2]): + trace_2.record({"a": 0}) + traces, length = pm.sampling._choose_chains([trace_0, trace_1, trace_2], tune=tune) + 
assert length == expected_length + assert expected_n_traces == len(traces) + + @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32") class TestNamedSampling(SeededTest): def test_shared_named(self):