Skip to content

Commit 9563d8a

Browse files
started splitting up dist. classes; some small testval test changes
1 parent 9f99178 commit 9563d8a

File tree

5 files changed

+78
-97
lines changed

5 files changed

+78
-97
lines changed

pymc3/distributions/continuous.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99

1010
import numpy as np
1111
import theano.tensor as tt
12+
from theano.gof.op import get_test_value
1213
from scipy import stats
1314
import warnings
1415

1516
from . import transforms
1617
from .dist_math import bound, logpow, gammaln, betaln, std_cdf, i0, i1
17-
from .distribution import UnivariateContinuous, draw_values, generate_samples
18+
from .distribution import Univariate, Continuous, draw_values, generate_samples
1819

1920
__all__ = ['Uniform', 'Flat', 'Normal', 'Beta', 'Exponential', 'Laplace',
2021
'StudentT', 'Cauchy', 'HalfCauchy', 'Gamma', 'Weibull',
@@ -23,16 +24,24 @@
2324
'VonMises', 'SkewNormal']
2425

2526

26-
class PositiveUnivariateContinuous(UnivariateContinuous):
27+
class PositiveUnivariateContinuous(Univariate, Continuous):
2728
"""Base class for positive univariate continuous distributions"""
2829

2930
def __init__(self, *args, **kwargs):
3031
transform = kwargs.get('transform', transforms.log)
3132
super(PositiveUnivariateContinuous, self).__init__(transform=transform,
3233
*args, **kwargs)
34+
# TODO: is there a better way to use Theano's
35+
# existing `...tag.test_value` mechanics?
36+
ndim_sum = self.ndim_supp + self.ndim_ind + self.ndim_reps
37+
if self.testval is None:
38+
if ndim_sum == 0:
39+
self.testval = 0.5
40+
else:
41+
self.testval = get_test_value(tt.alloc(*((0.5,) + self.shape)))
3342

3443

35-
class UnitUnivariateContinuous(UnivariateContinuous):
44+
class UnitUnivariateContinuous(Univariate, Continuous):
3645
"""Base class for univariate continuous distributions in [0,1]"""
3746

3847
def __init__(self, *args, **kwargs):
@@ -103,7 +112,7 @@ def get_tau_sd(tau=None, sd=None):
103112
return (tt.as_tensor_variable(tau), tt.as_tensor_variable(sd))
104113

105114

106-
class Uniform(UnivariateContinuous):
115+
class Uniform(Univariate, Continuous):
107116
R"""
108117
Continuous uniform log-likelihood.
109118
@@ -157,7 +166,7 @@ def logp(self, value):
157166
value >= lower, value <= upper)
158167

159168

160-
class Flat(UnivariateContinuous):
169+
class Flat(Univariate, Continuous):
161170
"""
162171
Uninformative log-likelihood that returns 0 regardless of
163172
the passed value.
@@ -178,7 +187,7 @@ def logp(self, value):
178187
return tt.zeros_like(value)
179188

180189

181-
class Normal(UnivariateContinuous):
190+
class Normal(Univariate, Continuous):
182191
R"""
183192
Univariate normal log-likelihood.
184193
@@ -573,7 +582,7 @@ def logp(self, value):
573582
return bound(tt.log(lam) - lam * value, value > 0, lam > 0)
574583

575584

576-
class Laplace(UnivariateContinuous):
585+
class Laplace(Univariate, Continuous):
577586
R"""
578587
Laplace log-likelihood.
579588
@@ -686,7 +695,7 @@ def logp(self, value):
686695
tau > 0)
687696

688697

689-
class StudentT(UnivariateContinuous):
698+
class StudentT(Univariate, Continuous):
690699
r"""
691700
Non-central Student's T log-likelihood.
692701
@@ -812,7 +821,7 @@ def logp(self, value):
812821
value >= m, alpha > 0, m > 0)
813822

814823

815-
class Cauchy(UnivariateContinuous):
824+
class Cauchy(Univariate, Continuous):
816825
R"""
817826
Cauchy log-likelihood.
818827
@@ -1154,7 +1163,7 @@ def logp(self, value):
11541163
value >= 0, alpha > 0, beta > 0)
11551164

11561165

1157-
class Bounded(UnivariateContinuous):
1166+
class Bounded(Univariate, Continuous):
11581167
R"""
11591168
An upper, lower or upper+lower bounded distribution
11601169
@@ -1269,8 +1278,7 @@ def StudentTpos(*args, **kwargs):
12691278

12701279
HalfStudentT = Bound(StudentT, lower=0)
12711280

1272-
1273-
class ExGaussian(UnivariateContinuous):
1281+
class ExGaussian(Univariate, Continuous):
12741282
R"""
12751283
Exponentially modified Gaussian log-likelihood.
12761284
@@ -1356,7 +1364,7 @@ def logp(self, value):
13561364
return bound(lp, sigma > 0., nu > 0.)
13571365

13581366

1359-
class VonMises(UnivariateContinuous):
1367+
class VonMises(Univariate, Continuous):
13601368
R"""
13611369
Univariate VonMises log-likelihood.
13621370

pymc3/distributions/discrete.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
from scipy import stats
77

88
from .dist_math import bound, factln, binomln, betaln, logpow
9-
from .distribution import UnivariateDiscrete, draw_values, generate_samples
9+
from .distribution import Univariate, Discrete, draw_values, generate_samples
1010

1111
__all__ = ['Binomial', 'BetaBinomial', 'Bernoulli', 'Poisson',
1212
'NegativeBinomial', 'ConstantDist', 'Constant', 'ZeroInflatedPoisson',
1313
'ZeroInflatedNegativeBinomial', 'DiscreteUniform', 'Geometric',
1414
'Categorical']
1515

1616

17-
class Binomial(UnivariateDiscrete):
17+
class Binomial(Univariate, Discrete):
1818
R"""
1919
Binomial log-likelihood.
2020
@@ -63,7 +63,7 @@ def logp(self, value):
6363
0 <= p, p <= 1)
6464

6565

66-
class BetaBinomial(UnivariateDiscrete):
66+
class BetaBinomial(Univariate, Discrete):
6767
R"""
6868
Beta-binomial log-likelihood.
6969
@@ -132,7 +132,7 @@ def logp(self, value):
132132
alpha > 0, beta > 0)
133133

134134

135-
class Bernoulli(UnivariateDiscrete):
135+
class Bernoulli(Univariate, Discrete):
136136
R"""Bernoulli log-likelihood
137137
138138
The Bernoulli distribution describes the probability of successes
@@ -174,7 +174,7 @@ def logp(self, value):
174174
p >= 0, p <= 1)
175175

176176

177-
class Poisson(UnivariateDiscrete):
177+
class Poisson(Univariate, Discrete):
178178
R"""
179179
Poisson log-likelihood.
180180
@@ -225,7 +225,7 @@ def logp(self, value):
225225
0, log_prob)
226226

227227

228-
class NegativeBinomial(UnivariateDiscrete):
228+
class NegativeBinomial(Univariate, Discrete):
229229
R"""
230230
Negative binomial log-likelihood.
231231
@@ -282,7 +282,7 @@ def logp(self, value):
282282
negbinom)
283283

284284

285-
class Geometric(UnivariateDiscrete):
285+
class Geometric(Univariate, Discrete):
286286
R"""
287287
Geometric log-likelihood.
288288
@@ -322,7 +322,7 @@ def logp(self, value):
322322
0 <= p, p <= 1, value >= 1)
323323

324324

325-
class DiscreteUniform(UnivariateDiscrete):
325+
class DiscreteUniform(Univariate, Discrete):
326326
R"""
327327
Discrete uniform distribution.
328328
@@ -373,7 +373,7 @@ def logp(self, value):
373373
lower <= value, value <= upper)
374374

375375

376-
class Categorical(UnivariateDiscrete):
376+
class Categorical(Univariate, Discrete):
377377
R"""
378378
Categorical log-likelihood.
379379
@@ -437,7 +437,7 @@ def logp(self, value):
437437
sumto1)
438438

439439

440-
class ConstantDist(UnivariateDiscrete):
440+
class ConstantDist(Univariate, Discrete):
441441
"""
442442
Constant log-likelihood.
443443
@@ -472,8 +472,7 @@ def ConstantDist(*args, **kwargs):
472472
DeprecationWarning)
473473
return Constant(*args, **kwargs)
474474

475-
476-
class ZeroInflatedPoisson(Discrete):
475+
class ZeroInflatedPoisson(Univariate, Discrete):
477476
R"""
478477
Zero-inflated Poisson log-likelihood.
479478

pymc3/distributions/distribution.py

Lines changed: 15 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -42,32 +42,6 @@ def _as_tensor_shape_variable(var):
4242
return res
4343

4444

45-
def _as_tensor_shape_variable(var):
46-
""" Just a collection of useful shape stuff from
47-
`_infer_ndim_bcast` """
48-
49-
if var is None:
50-
return T.constant([], dtype='int64')
51-
52-
res = var
53-
if isinstance(res, (tuple, list)):
54-
if len(res) == 0:
55-
return T.constant([], dtype='int64')
56-
res = T.as_tensor_variable(res, ndim=1)
57-
58-
else:
59-
if res.ndim != 1:
60-
raise TypeError("shape must be a vector or list of scalar, got\
61-
'%s'" % res)
62-
63-
if (not (res.dtype.startswith('int') or
64-
res.dtype.startswith('uint'))):
65-
66-
raise TypeError('shape must be an integer vector or list',
67-
res.dtype)
68-
return res
69-
70-
7145
class Distribution(object):
7246
"""Statistical distribution"""
7347
def __new__(cls, name, *args, **kwargs):
@@ -186,23 +160,20 @@ def __init__(self, shape_supp, shape_ind, shape_reps, bcast, dtype,
186160
tuple(self.shape_ind) +\
187161
tuple(self.shape_supp)
188162

189-
if testval is None:
190-
if ndim_sum == 0:
191-
testval = tt.constant(0, dtype=dtype)
192-
else:
193-
testval = tt.zeros(self.shape)
194-
195163
self.ndim = tt.get_vector_length(self.shape)
196-
197-
self.testval = testval
198164
self.defaults = defaults
199165
self.transform = transform
166+
167+
if testval is None:
168+
testval = self.get_test_value(defaults=self.defaults)
169+
170+
self.testval = testval
200171
self.type = tt.TensorType(str(dtype), bcast)
201172

202173
def default(self):
203-
return self.get_test_val(self.testval, self.defaults)
174+
return self.get_test_value(self.testval, self.defaults)
204175

205-
def get_test_val(self, val, defaults):
176+
def get_test_value(self, val=None, defaults=None):
206177
if val is None:
207178
for v in defaults:
208179
the_attr = getattr(self, v, None)
@@ -216,9 +187,14 @@ def get_test_val(self, val, defaults):
216187
str(defaults) + " pass testval argument or adjust so value is finite.")
217188

218189
def getattr_value(self, val):
190+
""" Attempts to obtain a non-symbolic value for an attribute
191+
(potentially given in str form)
192+
"""
219193
if isinstance(val, string_types):
220194
val = getattr(self, val)
221195

196+
# Could use Theano's:
197+
# val = theano.gof.op.get_test_value(val)
222198
if isinstance(val, tt.sharedvar.SharedVariable):
223199
return val.get_value()
224200
elif isinstance(val, tt.TensorVariable):
@@ -290,7 +266,7 @@ def __init__(self, logp, ndim_support, ndim, size, bcast, dtype='float64',
290266
self.logp = logp
291267

292268

293-
class UnivariateContinuous(Continuous):
269+
class Univariate(Distribution):
294270

295271
def __init__(self, dist_params, ndim=None, size=None, dtype=None,
296272
bcast=None, *args, **kwargs):
@@ -312,30 +288,15 @@ def __init__(self, dist_params, ndim=None, size=None, dtype=None,
312288
dtype = tt.scal.upcast(*(tt.config.floatX,) + tuple(x.dtype for x in dist_params))
313289

314290
# We just assume
315-
super(UnivariateContinuous, self).__init__(
291+
super(Univariate, self).__init__(
316292
tuple(), tuple(), size, bcast, *args, **kwargs)
317293

318294

319-
class MultivariateContinuous(Continuous):
295+
class Multivariate(Distribution):
320296

321297
pass
322298

323299

324-
325-
class MultivariateDiscrete(Discrete):
326-
327-
pass
328-
329-
330-
class UnivariateDiscrete(Discrete):
331-
332-
def __init__(self, ndim, size, bcast, *args, **kwargs):
333-
self.shape_supp = ()
334-
335-
super(UnivariateDiscrete, self).__init__(
336-
0, ndim, size, bcast, *args, **kwargs)
337-
338-
339300
def draw_values(params, point=None):
340301
"""
341302
Draw (fix) parameter values. Handles a number of cases:

0 commit comments

Comments
 (0)