3
3
__all__ = ['Zero' , 'Constant' ]
4
4
5
5
class Mean (object ):
6
- """
6
+ R """
7
7
Base class for mean functions
8
8
"""
9
-
10
9
def __call__ (self , X ):
11
10
R"""
12
11
Evaluate the mean function.
@@ -22,49 +21,52 @@ def __add__(self, other):
22
21
23
22
def __mul__ (self , other ):
24
23
return Prod (self , other )
25
-
24
+
26
25
class Zero (Mean ):
26
+ R"""
27
+ Zero mean function for Gaussian process.
28
+
29
+ """
27
30
def __call__ (self , X ):
28
- return tt .zeros (X .shape , dtype = 'float32' )
29
-
31
+ return tt .zeros (tt . stack ([ X .shape [ 0 ], ]) , dtype = 'float32' )
32
+
30
33
class Constant (Mean ):
31
- """
34
+ R """
32
35
Constant mean function for Gaussian process.
33
-
36
+
34
37
Parameters
35
38
----------
36
39
c : variable, array or integer
37
40
Constant mean value
38
41
"""
39
-
40
42
def __init__ (self , c = 0 ):
41
43
Mean .__init__ (self )
42
44
self .c = c
43
45
44
46
def __call__ (self , X ):
45
- return tt .ones (X .shape ) * self .c
47
+ return tt .ones (tt .stack ([X .shape [0 ], ])) * self .c
48
+
46
49
47
50
class Linear (Mean ):
48
-
51
+ R"""
52
+ Linear mean function for Gaussian process.
53
+
54
+ Parameters
55
+ ----------
56
+ coeffs : variables
57
+ Linear coefficients
58
+ intercept : variable, array or integer
59
+ Intercept for linear function (Defaults to zero)
60
+ """
49
61
def __init__ (self , coeffs , intercept = 0 ):
50
- """
51
- Linear mean function for Gaussian process.
52
-
53
- Parameters
54
- ----------
55
- coeffs : variables
56
- Linear coefficients
57
- intercept : variable, array or integer
58
- Intercept for linear function (Defaults to zero)
59
- """
60
62
Mean .__init__ (self )
61
63
self .b = intercept
62
64
self .A = coeffs
63
-
65
+
64
66
def __call__ (self , X ):
65
67
return tt .dot (X , self .A ) + self .b
66
68
67
-
69
+
68
70
class Add (Mean ):
69
71
def __init__ (self , first_mean , second_mean ):
70
72
Mean .__init__ (self )
@@ -83,4 +85,4 @@ def __init__(self, first_mean, second_mean):
83
85
84
86
def __call__ (self , X ):
85
87
return tt .mul (self .m1 (X ), self .m2 (X ))
86
-
88
+
0 commit comments