Skip to content

Commit 88e2941

Browse files
authored
refactor stein (#2283)
* refactor stein * fix typo
1 parent 3e8a000 commit 88e2941

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

pymc3/variational/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ def __init__(self, n_particles=100, jitter=.01, model=None, kernel=test_function
574574
start=start, model=model, local_rv=local_rv, random_seed=random_seed)
575575
super(SVGD, self).__init__(
576576
KSD, histogram,
577-
kernel,
577+
kernel, op_kwargs=dict(temperature=temperature),
578578
model=model, random_seed=random_seed)
579579

580580

pymc3/variational/stein.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,25 @@ def __init__(self, approx, kernel=rbf, input_matrix=None, temperature=1):
2020
@property
2121
@memoize
2222
def grad(self):
23-
t = self.approx.normalizing_constant
24-
Kxy, dxkxy = self.Kxy, self.dxkxy
25-
dlogpdx = self.dlogp # Normalized
2623
n = floatX(self.input_matrix.shape[0])
2724
temperature = self.temperature
28-
svgd_grad = (tt.dot(Kxy, dlogpdx)/temperature + dxkxy/t) / n
29-
return svgd_grad
25+
svgd_grad = (self.density_part_grad / temperature +
26+
self.repulsive_part_grad)
27+
return svgd_grad / n
28+
29+
@property
30+
@memoize
31+
def density_part_grad(self):
32+
Kxy = self.Kxy
33+
dlogpdx = self.dlogp
34+
return tt.dot(Kxy, dlogpdx)
35+
36+
@property
37+
@memoize
38+
def repulsive_part_grad(self):
39+
t = self.approx.normalizing_constant
40+
dxkxy = self.dxkxy
41+
return dxkxy/t
3042

3143
@property
3244
def Kxy(self):

0 commit comments

Comments
 (0)