Skip to content

Commit 416e5c2

Browse files
committed
Initial commit of pivoted cholesky algorith from GPyTorch
1 parent 4e6527c commit 416e5c2

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

pymc/gp/pivoted_cholesky.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import numpy as np
2+
import torch
3+
4+
from gpytorch.utils.permutation import apply_permutation
5+
6+
pp = lambda x: np.array2string(x, precision=4, floatmode="fixed")
7+
8+
9+
def pivoted_cholesky_np_gpt(mat: np.matrix, error_tol=1e-6, max_iter=np.Infinity):
10+
"""
11+
mat: numpy matrix of N x N
12+
13+
This is to replicate what is done in GPyTorch verbatim.
14+
"""
15+
n = mat.shape[-1]
16+
max_iter = min(int(max_iter), n)
17+
18+
d = np.array(np.diag(mat))
19+
orig_error = np.max(d)
20+
error = np.linalg.norm(d, 1) / orig_error
21+
pi = np.arange(n)
22+
23+
L = np.zeros((max_iter, n))
24+
25+
m = 0
26+
while m < max_iter and error > error_tol:
27+
permuted_d = d[pi]
28+
max_diag_idx = np.argmax(permuted_d[m:])
29+
max_diag_idx = max_diag_idx + m
30+
max_diag_val = permuted_d[max_diag_idx]
31+
i = max_diag_idx
32+
33+
# swap pi_m and pi_i
34+
pi[m], pi[i] = pi[i], pi[m]
35+
pim = pi[m]
36+
37+
L[m, pim] = np.sqrt(max_diag_val)
38+
39+
if m + 1 < n:
40+
row = apply_permutation(
41+
torch.from_numpy(mat), torch.tensor(pim), right_permutation=None
42+
) # left permutation just swaps row
43+
row = row.numpy().flatten()
44+
pi_i = pi[m + 1 :]
45+
L_m_new = row[pi_i] # length = 9
46+
47+
if m > 0:
48+
L_prev = L[:m, pi_i]
49+
update = L[:m, pim]
50+
prod = update @ L_prev
51+
L_m_new = L_m_new - prod # np.sum(prod, axis=-1)
52+
53+
L_m = L[m, :]
54+
L_m_new = L_m_new / L_m[pim]
55+
L_m[pi_i] = L_m_new
56+
57+
matrix_diag_current = d[pi_i]
58+
d[pi_i] = matrix_diag_current - L_m_new**2
59+
60+
L[m, :] = L_m
61+
error = np.linalg.norm(d[pi_i], 1) / orig_error
62+
m = m + 1
63+
return L, pi

scripts/run_mypy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
pymc/gp/gp.py
4444
pymc/gp/mean.py
4545
pymc/gp/util.py
46+
pymc/gp/pivoted_cholesky.py
4647
pymc/math.py
4748
pymc/ode/__init__.py
4849
pymc/ode/ode.py

0 commit comments

Comments
 (0)