|
| 1 | +import os |
1 | 2 | import unittest
|
2 | 3 |
|
3 | 4 | from .checks import close_to
|
|
8 | 9 | Metropolis, Slice, CompoundStep,
|
9 | 10 | MultivariateNormalProposal, HamiltonianMC)
|
10 | 11 | from pymc3.distributions import Binomial, Normal, Bernoulli, Categorical
|
11 |
| -from numpy.testing import assert_almost_equal |
| 12 | + |
| 13 | +from numpy.testing import assert_array_almost_equal |
12 | 14 | import numpy as np
|
| 15 | +from tqdm import tqdm |
13 | 16 |
|
14 | 17 |
|
15 | 18 | class TestStepMethods(object): # yield test doesn't work subclassing unittest.TestCase
|
| 19 | + master_samples = { |
| 20 | + Slice: np.array([ |
| 21 | + -8.13087389e-01, -3.08921856e-01, -6.79377098e-01, 6.50812585e-01, -7.63577596e-01, |
| 22 | + -8.13199793e-01, -1.63823548e+00, -7.03863676e-02, 2.05107771e+00, 1.68598170e+00, |
| 23 | + 6.92463695e-01, -7.75120766e-01, -1.62296463e+00, 3.59722423e-01, -2.31421712e-01, |
| 24 | + -7.80686956e-02, -6.05860731e-01, -1.13000202e-01, 1.55675942e-01, -6.78527612e-01, |
| 25 | + 6.31052333e-01, 6.09012517e-01, -1.56621643e+00, 5.04330883e-01, 3.14824082e-03, |
| 26 | + -1.31287073e+00, 4.10706927e-01, 8.93815792e-01, 8.19317020e-01, 3.71900919e-01, |
| 27 | + -2.62067312e+00, -3.47616592e+00, 1.50335041e+00, -1.05993351e+00, 2.41571723e-01, |
| 28 | + -1.06258156e+00, 5.87999429e-01, -1.78480091e-01, -3.60278680e-01, 1.90615274e-01, |
| 29 | + -1.24399204e-01, 4.03845589e-01, -1.47797573e-01, 7.90445804e-01, -1.21043819e+00, |
| 30 | + -1.33964776e+00, 1.36366329e+00, -7.50175388e-01, 9.25241839e-01, -4.17493767e-01, |
| 31 | + 1.85311339e+00, -2.49715343e+00, -3.18571692e-01, -1.49099668e+00, -2.62079621e-01, |
| 32 | + -5.82376852e-01, -2.53033395e+00, 2.07580503e+00, -9.82615856e-01, 6.00517782e-01, |
| 33 | + -9.83941620e-01, -1.59014118e+00, -1.83931394e-03, -4.71163466e-01, 1.90073737e+00, |
| 34 | + -2.08929125e-01, -6.98388847e-01, 1.64502092e+00, -1.19525944e+00, 1.44424109e+00, |
| 35 | + 1.52974876e+00, -5.70140077e-01, 5.08633322e-01, -1.70862492e-02, -1.69887948e-01, |
| 36 | + 5.19760297e-01, -4.15149647e-01, 8.63685174e-02, -3.66805233e-01, -9.24988952e-01, |
| 37 | + 2.33307122e+00, -2.60391496e-01, -5.86271814e-01, -5.01297170e-01, -1.53866195e+00, |
| 38 | + 5.71285373e-01, -1.30571830e+00, 8.59587795e-01, 6.72170694e-01, 9.12433943e-01, |
| 39 | + 7.04959179e-01, 8.37863464e-01, -5.24200836e-01, 1.28261340e+00, 9.08774240e-01, |
| 40 | + 8.80566763e-01, 7.82911967e-01, 8.01843432e-01, 7.09251098e-01, 5.73803618e-01]), |
| 41 | + HamiltonianMC: np.array([ |
| 42 | + -1.56440708e-03, -2.37766120e-03, -6.95819902e-03, -4.88882715e-03, -6.54928517e-03, |
| 43 | + -3.38653286e-03, -1.99381372e-03, -1.25904805e-03, -2.97173572e-04, -4.67391216e-04, |
| 44 | + -2.03821237e-03, -1.33693751e-04, -2.17293248e-03, -4.11675406e-03, -4.23091782e-03, |
| 45 | + -7.34120851e-03, -8.43726968e-03, -7.86976139e-03, -3.89551467e-03, -3.00788956e-03, |
| 46 | + -3.82420513e-03, -1.35604792e-03, -2.49066947e-04, 4.03633859e-04, 9.34321408e-05, |
| 47 | + 1.77722574e-03, 1.63761359e-03, 2.86208401e-03, -1.72243038e-04, 1.86863525e-03, |
| 48 | + 1.76740215e-03, 1.79169049e-03, 1.07164602e-03, 1.41264547e-03, 2.49563456e-03, |
| 49 | + 1.76639216e-03, 3.01570589e-03, 1.44186424e-04, 1.45073846e-03, 2.95031617e-04, |
| 50 | + -1.28811479e-04, -7.35945905e-04, -6.00689088e-04, 2.75468405e-04, 1.05245800e-03, |
| 51 | + 1.18892307e-03, 6.01165842e-04, 1.21016955e-03, -2.06751271e-03, -8.41426458e-04, |
| 52 | + 6.09905557e-04, 2.92765303e-03, 4.15216348e-03, 2.71863268e-03, 3.42922082e-03, |
| 53 | + 7.53890188e-03, 7.97507867e-03, 8.27371677e-03, 9.77811135e-03, 9.99705714e-03, |
| 54 | + 1.13996054e-02, 1.15745874e-02, 1.08182152e-02, 1.08277279e-02, 9.32254191e-03, |
| 55 | + 8.59914793e-03, 8.43927425e-03, 1.01570101e-02, 9.74607039e-03, 9.82868496e-03, |
| 56 | + 1.01745777e-02, 1.19312194e-02, 1.53760522e-02, 1.38691940e-02, 1.40131760e-02, |
| 57 | + 1.46184561e-02, 1.74382675e-02, 1.84241543e-02, 2.06913002e-02, 1.83520531e-02, |
| 58 | + 2.03072531e-02, 1.72912752e-02, 1.38959101e-02, 1.21933473e-02, 1.05084488e-02, |
| 59 | + 9.00532336e-03, 9.25863206e-03, 1.23618461e-02, 1.20207293e-02, 1.09334818e-02, |
| 60 | + 1.16528011e-02, 1.29967126e-02, 1.38940942e-02, 1.11408833e-02, 1.09263348e-02, |
| 61 | + 1.06521352e-02, 1.01622526e-02, 1.21998547e-02, 1.00880470e-02, 9.94787795e-03]), |
| 62 | + Metropolis: np.array([ |
| 63 | + 1.62434536, 1.01258895, 0.4844172, -0.58855142, 1.15626034, 0.39505344, 1.85716138, |
| 64 | + -0.20297933, -0.20297933, -0.20297933, -0.20297933, -1.08083775, -1.08083775, |
| 65 | + 0.06388596, 0.96474191, 0.28101405, 0.01312597, 0.54348144, -0.14369126, -0.98889691, |
| 66 | + -0.98889691, -0.75448121, -0.94631676, -0.94631676, -0.89550901, -0.89550901, |
| 67 | + -0.77535005, -0.15814694, 0.14202338, -0.21022647, -0.4191207, 0.16750249, 0.45308981, |
| 68 | + 1.33823098, 1.8511608, 1.55306796, 1.55306796, 1.55306796, 1.55306796, 0.15657163, |
| 69 | + 0.3166087, 0.3166087, 0.3166087, 0.3166087, 0.54670343, 0.54670343, 0.32437529, |
| 70 | + 0.12361722, 0.32191694, 0.44092559, 0.56274686, 0.56274686, 0.18746191, 0.18746191, |
| 71 | + -0.15639177, -0.11279491, -0.11279491, -0.11279491, -1.20770676, -1.03832432, |
| 72 | + -0.29776787, -1.25146848, -1.25146848, -0.93630908, -0.5857631, -0.5857631, |
| 73 | + -0.62445861, -0.62445861, -0.64907557, -0.64907557, -0.64907557, 0.58708846, |
| 74 | + -0.61217957, 0.25116575, 0.25116575, 0.80170324, 1.59451011, 0.97097938, 1.77284041, |
| 75 | + 1.81940771, 1.81940771, 1.81940771, 1.81940771, 1.95710892, 2.18960348, 2.18960348, |
| 76 | + 2.18960348, 2.18960348, 2.63096792, 2.53081269, 2.5482221, 1.42620337, 0.90910891, |
| 77 | + -0.08791792, 0.40729341, 0.23259025, 0.23259025, 0.23259025, 2.76091595, 2.51228118]), |
| 78 | + NUTS: np.array([ |
| 79 | + 0.68819657, 0.1767813, -0.59467679, -0.64216066, 1.63681405, 2.13404699, 0.03126563, |
| 80 | + 0.31817152, 0.31817152, 0.40191527, 0.40191527, 0.99220141, 0.93036804, -0.41228181, |
| 81 | + -1.80465851, -1.70577291, 0.19406438, 0.19406438, -0.03965181, -0.76135744, |
| 82 | + 0.70023098, 1.07183677, 1.07183677, 0.2829979, 1.13524135, -0.26461224, |
| 83 | + -0.39442329, -1.04109657, 0.79971205, 0.79971205, 0.96839778, 0.91868626, |
| 84 | + 0.19468837, 0.19468837, -0.67755668, -0.67755668, -0.43722432, 0.12072881, |
| 85 | + 0.6267432, 0.6861771, 0.4669198, 0.4669198, -0.08143768, 0.27691068, 0.11510718, |
| 86 | + 2.29821426, 2.18308403, 1.16618069, -0.45615197, -0.45615197, -0.37076172, |
| 87 | + -0.37076172, -0.38889599, 0.36200553, -0.55179735, -0.55179735, -0.18946703, |
| 88 | + 1.11552335, 0.98985795, 0.98985795, 1.00313687, -0.18458164, 0.44025584, 0.97610126, |
| 89 | + -0.1558578, -0.1558578, -0.01247235, -0.08303131, 0.52019377, -1.52329796, |
| 90 | + -1.72856248, -1.19049049, -1.19049049, -0.8651521, -0.36421118, -0.40590409, |
| 91 | + -0.78925074, -0.53960924, -0.53960924, 0.1069186, 0.40849997, 0.1560954, |
| 92 | + 0.35461684, 0.35461684, -0.83935418, -0.85295353, -0.13990269, -0.1412904, |
| 93 | + -0.1412904, -0.30071575, -0.296461, 0.06540186, -0.15145479, -0.15145479, |
| 94 | + -0.21406771, -0.21533218, 0.06833495, 0.06833495, -0.18763595, 0.34138144]), |
| 95 | + } |
| 96 | + |
| 97 | + def test_sample_exact(self): |
| 98 | + for step_method in self.master_samples: |
| 99 | + yield self.check_trace, step_method |
| 100 | + |
| 101 | + def check_trace(self, step_method): |
| 102 | + """Tests whether the trace for step methods is exactly the same as on master. |
| 103 | +
|
| 104 | + Code changes that effect how random numbers are drawn may change this, and require |
| 105 | + `master_samples` to be updated, but such changes should be noted and justified in the |
| 106 | + commit. |
| 107 | +
|
| 108 | + This method may also be used to benchmark step methods across commits, by running, for |
| 109 | + example |
| 110 | +
|
| 111 | + ``` |
| 112 | + BENCHMARK=100000 ./scripts/test.sh -s pymc3/tests/test_step.py:TestStepMethods |
| 113 | + ``` |
| 114 | +
|
| 115 | + on multiple commits. |
| 116 | + """ |
| 117 | + test_steps = 100 |
| 118 | + n_steps = int(os.getenv('BENCHMARK', 100)) |
| 119 | + benchmarking = (n_steps != test_steps) |
| 120 | + if benchmarking: |
| 121 | + tqdm.write('Benchmarking {} with {:,d} samples'.format(step_method.__name__, n_steps)) |
| 122 | + else: |
| 123 | + tqdm.write('Checking {} has same trace as on master'.format(step_method.__name__)) |
| 124 | + with Model(): |
| 125 | + Normal('x', mu=0, sd=1) |
| 126 | + trace = sample(n_steps, step=step_method(), random_seed=1) |
| 127 | + |
| 128 | + if not benchmarking: |
| 129 | + assert_array_almost_equal(trace.get_values('x'), self.master_samples[step_method]) |
| 130 | + |
16 | 131 | def check_stat(self, check, trace):
|
17 | 132 | for (var, stat, value, bound) in check:
|
18 | 133 | s = stat(trace[var][2000:], axis=0)
|
@@ -60,8 +175,8 @@ def test_step_categorical(self):
|
60 | 175 | ('x', np.std, unc, unc / 10.))
|
61 | 176 | with model:
|
62 | 177 | steps = (
|
63 |
| - CategoricalGibbsMetropolis(model.x, proposal = 'uniform'), |
64 |
| - CategoricalGibbsMetropolis(model.x, proposal = 'proportional'), |
| 178 | + CategoricalGibbsMetropolis(model.x, proposal='uniform'), |
| 179 | + CategoricalGibbsMetropolis(model.x, proposal='proportional'), |
65 | 180 | )
|
66 | 181 | for step in steps:
|
67 | 182 | trace = sample(8000, step=step, start=start, model=model, random_seed=1)
|
|
0 commit comments