Skip to content

Commit 3c18c4d

Browse files
Vaibhavdixit02ColCarroll
authored andcommitted
Changes in test for random kwarg in DensityDist (#2840)
* Added test * Test scipy distribution compatibility with DensityDist * Updated RELEASE-NOTES.md * Changes in test for random method in DensityDist * Added docstring for DensityDist * Update tests for random method with DensityDist
1 parent 38bfc65 commit 3c18c4d

File tree

3 files changed

+51
-44
lines changed

3 files changed

+51
-44
lines changed

RELEASE-NOTES.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
### New features
66

7+
<<<<<<< refs/remotes/pymc-devs/master
78
- Add `logit_p` keyword to `pm.Bernoulli`, so that users can specify the logit of the success probability. This is faster and more stable than using `p=tt.nnet.sigmoid(logit_p)`.
89
- Add `random` keyword to `pm.DensityDist` thus enabling users to pass custom random method which in turn makes sampling from a `DensityDist` possible.
910
- Effective sample size computation is updated. The estimation uses Geyer's initial positive sequence, which no longer truncates the autocorrelation series inaccurately. `pm.diagnostics.effective_n` now can reports N_eff>N.
@@ -19,6 +20,18 @@
1920
- The bandwidth for KDE plots is computed using a modified version of Scott's rule. The new version uses entropy instead of standard
2021
deviation. This works better for multimodal distributions. Functions using KDE plots has a new argument `bw` controlling the bandwidth.
2122

23+
=======
24+
- Add `logit_p` keyword to `pm.Bernoulli`, so that users can specify the logit
25+
of the success probability. This is faster and more stable than using
26+
`p=tt.nnet.sigmoid(logit_p)`.
27+
- Add `random` keyword to `pm.DensityDist` thus enabling users to pass custom random method
28+
which in turn makes sampling from a `DensityDist` possible.
29+
30+
### Fixes
31+
32+
- `VonMises` does not overflow for large values of kappa. i0 and i1 have been removed and we now use
33+
log_i0 to compute the logp.
34+
>>>>>>> Changes in test for random method in DensityDist
2235
2336
### Deprecations
2437

pymc3/distributions/distribution.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,22 @@ def __init__(self, shape=(), dtype=None, defaults=('median', 'mean', 'mode'),
176176

177177

178178
class DensityDist(Distribution):
179-
"""Distribution based on a given log density function."""
179+
"""Distribution based on a given log density function.
180+
181+
A distribution with the passed log density function is created.
182+
Requires a custom random function passed as kwarg `random` to
183+
enable sampling.
184+
185+
Example:
186+
--------
187+
.. code-block:: python
188+
with pm.Model():
189+
mu = pm.Normal('mu',0,1)
190+
normal_dist = pm.Normal.dist(mu, 1)
191+
pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100), random=normal_dist.random)
192+
trace = pm.sample(100)
193+
194+
"""
180195

181196
def __init__(self, logp, shape=(), dtype=None, testval=0, random=None, *args, **kwargs):
182197
if dtype is None:

pymc3/tests/test_distributions_random.py

Lines changed: 22 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -741,46 +741,25 @@ def ref_rand(size, w, mu, sd):
741741
size=1000,
742742
ref_rand=ref_rand)
743743

744-
def test_density_dist(self):
745-
def ref_rand(size, mu, sd):
746-
return st.norm.rvs(size=size, loc=mu, scale=sd)
747-
748-
class TestDensityDist(pm.DensityDist):
749-
750-
def __init__(self, **kwargs):
751-
norm_dist = pm.Normal.dist()
752-
super(TestDensityDist, self).__init__(logp=norm_dist.logp, random=norm_dist.random)
753-
754-
pymc3_random(TestDensityDist, {},ref_rand=ref_rand)
755-
756-
def check_model_samplability(self):
757-
model = pm.Model()
758-
with model:
759-
normal_dist = pm.Normal.dist()
760-
density_dist = pm.DensityDist('density_dist', normal_dist.logp, random=normal_dist.random)
761-
step = pm.Metropolis()
762-
trace = pm.sample(100, step, tuning=0)
763-
764-
try:
765-
ppc = pm.sample_ppc(trace, samples=500, model=model, size=100)
766-
if len(ppc) == 0:
767-
npt.assert_true(len(ppc) == 0, 'length of ppc sample is zero')
768-
except:
769-
assert False
770-
771-
def check_scipy_distributions(self):
772-
model = pm.Model()
773-
with model:
774-
norm_dist_logp = st.norm.logpdf
775-
norm_dist_random = np.random.normal
776-
density_dist = pm.DensityDist('density_dist', normal_dist_logp, random=normal_dist_random)
777-
step = pm.Metropolis()
778-
trace = pm.sample(100, step, tuning=0)
779-
780-
try:
781-
ppc = pm.sample_ppc(trace, samples=500, model=model, size=100)
782-
if len(ppc) == 0:
783-
npt.assert_true(len(ppc) == 0, 'length of ppc sample is zero')
784-
except:
785-
assert False
786-
744+
def test_density_dist_with_random_sampleable():
745+
with pm.Model() as model:
746+
mu = pm.Normal('mu',0,1)
747+
normal_dist = pm.Normal.dist(mu, 1)
748+
pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100), random=normal_dist.random)
749+
trace = pm.sample(100)
750+
751+
samples = 500
752+
ppc = pm.sample_ppc(trace, samples=samples, model=model, size=100)
753+
assert len(ppc['density_dist']) == samples
754+
755+
756+
def test_density_dist_without_random_not_sampleable():
757+
with pm.Model() as model:
758+
mu = pm.Normal('mu',0,1)
759+
normal_dist = pm.Normal.dist(mu, 1)
760+
pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100))
761+
trace = pm.sample(100)
762+
763+
samples = 500
764+
with pytest.raises(ValueError):
765+
pm.sample_ppc(trace, samples=samples, model=model, size=100)

0 commit comments

Comments
 (0)