Skip to content

Commit 950c18e

Browse files
committed
Merge pull request pymc-devs#586 from maahnman/master
New example and additional keyword for autocorrplot
2 parents 077190f + 7ac2304 commit 950c18e

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
Similar to disaster_model.py, but for arbitrary
3+
determinsitics which are not not working with Theano.
4+
Note that gradient based samplers will not work.
5+
"""
6+
7+
8+
from pymc import *
9+
import theano.tensor as t
10+
from numpy import arange, array, ones, concatenate, empty
11+
from numpy.random import randint
12+
13+
__all__ = ['disasters_data', 'switchpoint', 'early_mean', 'late_mean', 'rate',
14+
'disasters']
15+
16+
# Time series of recorded coal mining disasters in the UK from 1851 to 1962
17+
disasters_data = array([4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
18+
3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
19+
2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0,
20+
1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
21+
0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
22+
3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
23+
0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1])
24+
years = len(disasters_data)
25+
26+
#here is the trick
27+
@theano.compile.ops.as_op(itypes=[t.lscalar, t.dscalar, t.dscalar],otypes=[t.dvector])
28+
def rateFunc(switchpoint,early_mean, late_mean):
29+
''' Concatenate Poisson means '''
30+
out = empty(years)
31+
out[:switchpoint] = early_mean
32+
out[switchpoint:] = late_mean
33+
return out
34+
35+
36+
with Model() as model:
37+
38+
# Prior for distribution of switchpoint location
39+
switchpoint = DiscreteUniform('switchpoint', lower=0, upper=years)
40+
# Priors for pre- and post-switch mean number of disasters
41+
early_mean = Exponential('early_mean', lam=1.)
42+
late_mean = Exponential('late_mean', lam=1.)
43+
44+
# Allocate appropriate Poisson rates to years before and after current
45+
# switchpoint location
46+
idx = arange(years)
47+
#theano style:
48+
#rate = switch(switchpoint >= idx, early_mean, late_mean)
49+
#non-theano style
50+
rate = rateFunc(switchpoint, early_mean, late_mean)
51+
52+
# Data likelihood
53+
disasters = Poisson('disasters', rate, observed=disasters_data)
54+
55+
# Initial values for stochastic nodes
56+
start = {'early_mean': 2., 'late_mean': 3.}
57+
58+
# Use slice sampler for means
59+
step1 = Slice([early_mean, late_mean])
60+
# Use Metropolis for switchpoint, since it accomodates discrete variables
61+
step2 = Metropolis([switchpoint])
62+
63+
# njobs>1 works only with most recent (mid August 2014) Thenao version:
64+
# https://github.com/Theano/Theano/pull/2021
65+
tr = sample(1000, tune=500, start=start, step=[step1, step2],njobs=1)
66+
traceplot(tr)

pymc/plots.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def kde2plot(x, y, grid=200):
129129
return f
130130

131131

132-
def autocorrplot(trace, vars=None, fontmap=None, max_lag=100):
132+
def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
133133
"""Bar plot of the autocorrelation function for a trace"""
134134
import matplotlib.pyplot as plt
135135
if fontmap is None:
@@ -148,7 +148,7 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100):
148148

149149
for i, v in enumerate(vars):
150150
for j in range(chains):
151-
d = np.squeeze(trace.get_values(v, chains=[j]))
151+
d = np.squeeze(trace.get_values(v, chains=[j],burn=burn,thin=thin))
152152

153153
ax[i, j].acorr(d, detrend=plt.mlab.detrend_mean, maxlags=max_lag)
154154

0 commit comments

Comments
 (0)