Skip to content

Commit 3ec9da3

Browse files
bwengalstwiecki
authored andcommitted
fix for multi-dim input bug in Zero and Constant, tidy up docstrings, add tests
1 parent 57726c3 commit 3ec9da3

File tree

2 files changed

+126
-79
lines changed

2 files changed

+126
-79
lines changed

pymc3/gp/mean.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
__all__ = ['Zero', 'Constant']
44

55
class Mean(object):
6-
"""
6+
R"""
77
Base class for mean functions
88
"""
9-
109
def __call__(self, X):
1110
R"""
1211
Evaluate the mean function.
@@ -22,49 +21,52 @@ def __add__(self, other):
2221

2322
def __mul__(self, other):
2423
return Prod(self, other)
25-
24+
2625
class Zero(Mean):
26+
R"""
27+
Zero mean function for Gaussian process.
28+
29+
"""
2730
def __call__(self, X):
28-
return tt.zeros(X.shape, dtype='float32')
29-
31+
return tt.zeros(tt.stack([X.shape[0], ]), dtype='float32')
32+
3033
class Constant(Mean):
31-
"""
34+
R"""
3235
Constant mean function for Gaussian process.
33-
36+
3437
Parameters
3538
----------
3639
c : variable, array or integer
3740
Constant mean value
3841
"""
39-
4042
def __init__(self, c=0):
4143
Mean.__init__(self)
4244
self.c = c
4345

4446
def __call__(self, X):
45-
return tt.ones(X.shape) * self.c
47+
return tt.ones(tt.stack([X.shape[0], ])) * self.c
48+
4649

4750
class Linear(Mean):
48-
51+
R"""
52+
Linear mean function for Gaussian process.
53+
54+
Parameters
55+
----------
56+
coeffs : variables
57+
Linear coefficients
58+
intercept : variable, array or integer
59+
Intercept for linear function (Defaults to zero)
60+
"""
4961
def __init__(self, coeffs, intercept=0):
50-
"""
51-
Linear mean function for Gaussian process.
52-
53-
Parameters
54-
----------
55-
coeffs : variables
56-
Linear coefficients
57-
intercept : variable, array or integer
58-
Intercept for linear function (Defaults to zero)
59-
"""
6062
Mean.__init__(self)
6163
self.b = intercept
6264
self.A = coeffs
63-
65+
6466
def __call__(self, X):
6567
return tt.dot(X, self.A) + self.b
6668

67-
69+
6870
class Add(Mean):
6971
def __init__(self, first_mean, second_mean):
7072
Mean.__init__(self)
@@ -83,4 +85,4 @@ def __init__(self, first_mean, second_mean):
8385

8486
def __call__(self, X):
8587
return tt.mul(self.m1(X), self.m2(X))
86-
88+

0 commit comments

Comments
 (0)