Skip to content

Commit 307913d

Browse files
authored
Initial commit of pivoted Cholesky algorithm from GPyTorch (#63)
* Initial commit of the linear conjugate gradients * Updated __init__ so that we can import linear_cg * Initial commit of pivoted cholesky * Fixed the name of pivoted cholesky function * Since we are invoking the function as a class attribute, removing the self * Removing linear conjugate gradients from this branch * Also removing linear cg from __init__ * Fixed the import * Adding dependencies for pivoted_cholesky, they are also needed for tests * Added correct package for pytorch * Added try...catch block to ensure users who don't have these packages installed, and don't want to use pivoted Cholesky, can use pymc * Added import checks in the test file as well * removed unused commits * removing pytorch and gpytorch from requirements. * pre-commit wouldn't let me commit print statements * removing the test for now * Raising an ImportError instead of printing * Addressing stylistic comment * formatting fixes * formatting fixes * pre-commit modifications
1 parent 77a3f8f commit 307913d

File tree

3 files changed

+92
-0
lines changed

3 files changed

+92
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# try:
2+
# import gpytorch
3+
# import torch
4+
# except ImportError as e:
5+
# # print(
6+
# # f"Please install Pytorch and GPyTorch to use this pivoted Cholesky implementation. Error {e}"
7+
# # )
8+
# pass
9+
# import numpy as np
10+
#
11+
# import pymc_experimental as pmx
12+
#
13+
#
14+
# def test_match_gpytorch_linearcg_output():
15+
# N = 10
16+
# rank = 5
17+
# np.random.seed(1234) # nans with seed 1234
18+
# K = np.random.randn(N, N)
19+
# K = K @ K.T + N * np.eye(N)
20+
# K_torch = torch.from_numpy(K)
21+
#
22+
# L_gpt = gpytorch.pivoted_cholesky(K_torch, rank=rank, error_tol=1e-3)
23+
# L_np, _ = pmx.utils.pivoted_cholesky(K, max_iter=rank, error_tol=1e-3)
24+
# assert np.allclose(L_gpt, L_np.T)

pymc_experimental/utils/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@
1515

1616
from pymc_experimental.utils import prior, spline
1717
from pymc_experimental.utils.linear_cg import linear_cg
18+
19+
# from pymc_experimental.utils.pivoted_cholesky import pivoted_cholesky
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
try:
2+
import torch
3+
from gpytorch.utils.permutation import apply_permutation
4+
except ImportError as e:
5+
raise ImportError("PyTorch and GPyTorch not found.") from e
6+
7+
import numpy as np
8+
9+
pp = lambda x: np.array2string(x, precision=4, floatmode="fixed")
10+
11+
12+
def pivoted_cholesky(mat: np.matrix, error_tol=1e-6, max_iter=np.inf):
13+
"""
14+
mat: numpy matrix of N x N
15+
16+
This is to replicate what is done in GPyTorch verbatim.
17+
"""
18+
n = mat.shape[-1]
19+
max_iter = min(int(max_iter), n)
20+
21+
d = np.array(np.diag(mat))
22+
orig_error = np.max(d)
23+
error = np.linalg.norm(d, 1) / orig_error
24+
pi = np.arange(n)
25+
26+
L = np.zeros((max_iter, n))
27+
28+
m = 0
29+
while m < max_iter and error > error_tol:
30+
permuted_d = d[pi]
31+
max_diag_idx = np.argmax(permuted_d[m:])
32+
max_diag_idx = max_diag_idx + m
33+
max_diag_val = permuted_d[max_diag_idx]
34+
i = max_diag_idx
35+
36+
# swap pi_m and pi_i
37+
pi[m], pi[i] = pi[i], pi[m]
38+
pim = pi[m]
39+
40+
L[m, pim] = np.sqrt(max_diag_val)
41+
42+
if m + 1 < n:
43+
row = apply_permutation(
44+
torch.from_numpy(mat), torch.tensor(pim), right_permutation=None
45+
) # left permutation just swaps row
46+
row = row.numpy().flatten()
47+
pi_i = pi[m + 1 :]
48+
L_m_new = row[pi_i] # length = 9
49+
50+
if m > 0:
51+
L_prev = L[:m, pi_i]
52+
update = L[:m, pim]
53+
prod = update @ L_prev
54+
L_m_new = L_m_new - prod # np.sum(prod, axis=-1)
55+
56+
L_m = L[m, :]
57+
L_m_new = L_m_new / L_m[pim]
58+
L_m[pi_i] = L_m_new
59+
60+
matrix_diag_current = d[pi_i]
61+
d[pi_i] = matrix_diag_current - L_m_new**2
62+
63+
L[m, :] = L_m
64+
error = np.linalg.norm(d[pi_i], 1) / orig_error
65+
m = m + 1
66+
return L, pi

0 commit comments

Comments
 (0)