Skip to content

Commit 52a00e9

Browse files
ferrinetaku-y
authored andcommitted
pretty rbf
1 parent 5e1ea61 commit 52a00e9

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

pymc3/variational/test_functions.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from theano import tensor as tt
1+
from theano import theano, tensor as tt
22
from .opvi import TestFunction
33

44

@@ -21,9 +21,9 @@ class Kernel(TestFunction):
2121
class RBF(Kernel):
2222
def __call__(self, X):
2323
XY = X.dot(X.T)
24-
x2 = tt.reshape(tt.sum(tt.square(X), axis=1), (X.shape[0], 1))
24+
x2 = tt.sum(X ** 2, axis=1).dimshuffle(0, 'x')
2525
X2e = tt.repeat(x2, X.shape[0], axis=1)
26-
H = tt.sub(tt.add(X2e, X2e.T), 2 * XY)
26+
H = X2e + X2e.T - 2. * XY
2727

2828
V = tt.sort(H.flatten())
2929
length = V.shape[0]
@@ -34,9 +34,12 @@ def __call__(self, X):
3434
# if odd vector
3535
V[length // 2])
3636

37-
h = 0.5 * m / tt.log(X.shape[0].astype('float32') + 1.0)
37+
h = .5 * m / tt.log(tt.cast(H.shape[0] + 1., theano.config.floatX))
3838

39+
# RBF
3940
Kxy = tt.exp(-H / h / 2.0)
41+
42+
# Derivative
4043
dxkxy = -tt.dot(Kxy, X)
4144
sumkxy = tt.sum(Kxy, axis=1).dimshuffle(0, 'x')
4245
dxkxy = tt.add(dxkxy, tt.mul(X, sumkxy)) / h

0 commit comments

Comments
 (0)