Refactor alpha_recover and inverse_hessian_factors to remove update_m… #462

Merged: 4 commits, Apr 19, 2025
29 changes: 16 additions & 13 deletions pymc_extras/inference/pathfinder/lbfgs.py
@@ -37,11 +37,14 @@ class LBFGSHistoryManager:
initial position
maxiter : int
maximum number of iterations to store
+ epsilon : float
+ tolerance for lbfgs update
"""

value_grad_fn: Callable[[NDArray[np.float64]], tuple[np.float64, NDArray[np.float64]]]
x0: NDArray[np.float64]
maxiter: int
+ epsilon: float
x_history: NDArray[np.float64] = field(init=False)
g_history: NDArray[np.float64] = field(init=False)
count: int = field(init=False)
@@ -85,10 +88,9 @@ def entry_condition_met(self, x, value, grad) -> bool:
s = x - self.x_history[self.count - 1]
z = grad - self.g_history[self.count - 1]
sz = (s * z).sum(axis=-1)
- epsilon = 1e-8
- update_mask = sz > epsilon * np.sqrt(np.sum(z**2, axis=-1))
+ update = sz > self.epsilon * np.sqrt(np.sum(z**2, axis=-1))

- if update_mask:
+ if update:
return True
else:
return False
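For intuition, a standalone NumPy sketch of the acceptance test above (values are made up; `epsilon` stands in for `self.epsilon`). A pair of iterates is stored only when the curvature inner product s·z clears the epsilon-scaled norm of the gradient difference:

```python
import numpy as np

epsilon = 1e-8  # stands in for self.epsilon

# hypothetical consecutive L-BFGS positions and gradients
x_prev, x = np.array([0.0, 0.0]), np.array([0.1, -0.2])
g_prev, g = np.array([1.0, 1.0]), np.array([1.3, 0.6])

s = x - x_prev             # position difference
z = g - g_prev             # gradient difference
sz = (s * z).sum(axis=-1)  # curvature inner product: 0.11 here

update = sz > epsilon * np.sqrt(np.sum(z**2, axis=-1))
print(update)  # True: this (s, z) pair would be stored in the history
```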
@@ -105,10 +107,10 @@ class LBFGSStatus(Enum):
CONVERGED = auto()
MAX_ITER_REACHED = auto()
NON_FINITE = auto()
- LOW_UPDATE_MASK_RATIO = auto()
+ LOW_UPDATE_PCT = auto()
# Statuses that lead to Exceptions:
INIT_FAILED = auto()
- INIT_FAILED_LOW_UPDATE_MASK = auto()
+ INIT_FAILED_LOW_UPDATE_PCT = auto()
LBFGS_FAILED = auto()


@@ -144,17 +146,20 @@ class LBFGS:
gradient tolerance for convergence, defaults to 1e-8
maxls : int, optional
maximum number of line search steps, defaults to 1000
+ epsilon : float, optional
+ tolerance for lbfgs update, defaults to 1e-8
"""

def __init__(
- self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000
+ self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000, epsilon=1e-8
) -> None:
self.value_grad_fn = value_grad_fn
self.maxcor = maxcor
self.maxiter = maxiter
self.ftol = ftol
self.gtol = gtol
self.maxls = maxls
+ self.epsilon = epsilon
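A hypothetical construction showing where the new argument lands, on a toy quadratic objective (only the class and its signature are taken from this diff; the objective and values are illustrative):

```python
import numpy as np
from pymc_extras.inference.pathfinder.lbfgs import LBFGS

def value_grad_fn(x):
    # toy objective f(x) = 0.5 * ||x||^2 with gradient x
    return 0.5 * np.sum(x**2), x

lbfgs = LBFGS(value_grad_fn, maxcor=6, epsilon=1e-8)
x_hist, g_hist, nit, status = lbfgs.minimize(np.ones(3))
```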

def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
"""minimizes objective function starting from initial position.
@@ -179,7 +184,7 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
x0 = np.array(x0, dtype=np.float64)

history_manager = LBFGSHistoryManager(
- value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter
+ value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter, epsilon=self.epsilon
)

result = minimize(
@@ -199,24 +204,22 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
history = history_manager.get_history()

# warnings and suggestions for LBFGSStatus are displayed at the end
- # threshold determining if the number of update mask is low compared to iterations
+ # threshold determining if the number of lbfgs updates is low compared to iterations
low_update_threshold = 3

- logging.warning(f"LBFGS status: {result} \n nit: {result.nit} \n count: {history.count}")

if history.count <= 1: # triggers LBFGSInitFailed
if result.nit < low_update_threshold:
lbfgs_status = LBFGSStatus.INIT_FAILED
else:
- lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK
+ lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT
elif result.status == 1:
# (result.nit > maxiter) or (result.nit > maxls)
lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
elif result.status == 2:
# precision loss resulting to inf or nan
lbfgs_status = LBFGSStatus.NON_FINITE
- elif history.count < low_update_threshold * result.nit:
- lbfgs_status = LBFGSStatus.LOW_UPDATE_MASK_RATIO
+ elif history.count * low_update_threshold < result.nit:
+ lbfgs_status = LBFGSStatus.LOW_UPDATE_PCT
else:
lbfgs_status = LBFGSStatus.CONVERGED
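A worked check of what the rewritten inequality changes (counts are hypothetical). The old test `history.count < low_update_threshold * result.nit` held for nearly any run, while the new form fires only when fewer than a third of iterations produced accepted updates:

```python
low_update_threshold = 3

count, nit = 4, 20   # few accepted updates over many iterations
assert (count < low_update_threshold * nit) is True    # old: 4 < 60
assert (count * low_update_threshold < nit) is True    # new: 12 < 20

count, nit = 15, 20  # healthy run
assert (count < low_update_threshold * nit) is True    # old: 15 < 60 (false positive)
assert (count * low_update_threshold < nit) is False   # new: 45 < 20 (correctly passes)
```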

91 changes: 26 additions & 65 deletions pymc_extras/inference/pathfinder/pathfinder.py
@@ -237,8 +237,8 @@ def convert_flat_trace_to_idata(


def alpha_recover(
- x: TensorVariable, g: TensorVariable, epsilon: TensorVariable
- ) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
+ x: TensorVariable, g: TensorVariable
+ ) -> tuple[TensorVariable, TensorVariable, TensorVariable]:
"""compute the diagonal elements of the inverse Hessian at each iterations of L-BFGS and filter updates.

Parameters
@@ -247,9 +247,6 @@
position array, shape (L+1, N)
g : TensorVariable
gradient array, shape (L+1, N)
- epsilon : float
- threshold for filtering updates based on inner product of position
- and gradient differences

Returns
-------
@@ -259,15 +256,13 @@
position differences, shape (L, N)
z : TensorVariable
gradient differences, shape (L, N)
- update_mask : TensorVariable
- mask for filtering updates, shape (L,)

Notes
-----
shapes: L=batch_size, N=num_params
"""

- def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
+ def compute_alpha_l(s_l, z_l, alpha_lm1) -> TensorVariable:
# alpha_lm1: (N,)
# s_l: (N,)
# z_l: (N,)
@@ -281,43 +276,28 @@ def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
) # fmt:off
return 1.0 / inv_alpha_l

- def return_alpha_lm1(alpha_lm1, s_l, z_l) -> TensorVariable:
- return alpha_lm1[-1]
-
- def scan_body(update_mask_l, s_l, z_l, alpha_lm1) -> TensorVariable:
- return pt.switch(
- update_mask_l,
- compute_alpha_l(alpha_lm1, s_l, z_l),
- return_alpha_lm1(alpha_lm1, s_l, z_l),
- )

Lp1, N = x.shape
s = pt.diff(x, axis=0)
z = pt.diff(g, axis=0)
alpha_l_init = pt.ones(N)
- sz = (s * z).sum(axis=-1)
- # update_mask = sz > epsilon * pt.linalg.norm(z, axis=-1)
- # pt.linalg.norm does not work with JAX!!
- update_mask = sz > epsilon * pt.sqrt(pt.sum(z**2, axis=-1))

alpha, _ = pytensor.scan(
- fn=scan_body,
+ fn=compute_alpha_l,
outputs_info=alpha_l_init,
- sequences=[update_mask, s, z],
+ sequences=[s, z],
n_steps=Lp1 - 1,
allow_gc=False,
)

- # assert np.all(alpha.eval() > 0), "alpha cannot be negative"
- # alpha: (L, N), update_mask: (L, N)
- return alpha, s, z, update_mask
+ # alpha: (L, N)
+ return alpha, s, z
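The argument order in `compute_alpha_l` changed because `pytensor.scan` passes slices of `sequences` first and the carried value from `outputs_info` last. A minimal runnable illustration with a toy recurrence (not the alpha formula itself):

```python
import pytensor
import pytensor.tensor as pt

s = pt.matrix("s")  # (L, N)
z = pt.matrix("z")  # (L, N)

def step(s_l, z_l, acc_lm1):
    # sequence slices arrive first, the recurrent state last
    return acc_lm1 + s_l * z_l

acc, _ = pytensor.scan(fn=step, sequences=[s, z], outputs_info=pt.zeros_like(s[0]))
```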


def inverse_hessian_factors(
alpha: TensorVariable,
s: TensorVariable,
z: TensorVariable,
- update_mask: TensorVariable,
J: TensorConstant,
) -> tuple[TensorVariable, TensorVariable]:
"""compute the inverse hessian factors for the BFGS approximation.
@@ -330,8 +310,6 @@
position differences, shape (L, N)
z : TensorVariable
gradient differences, shape (L, N)
- update_mask : TensorVariable
- mask for filtering updates, shape (L,)
J : TensorConstant
history size for L-BFGS

@@ -350,30 +328,19 @@
# NOTE: get_chi_matrix_1 is a modified version of get_chi_matrix_2 to closely follow Zhang et al., (2022)
# NOTE: get_chi_matrix_2 is from blackjax which MAYBE incorrectly implemented

- def get_chi_matrix_1(
- diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
- ) -> TensorVariable:
+ def get_chi_matrix_1(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
L, N = diff.shape
j_last = pt.as_tensor(J - 1) # since indexing starts at 0

- def chi_update(chi_lm1, diff_l) -> TensorVariable:
+ def chi_update(diff_l, chi_lm1) -> TensorVariable:
chi_l = pt.roll(chi_lm1, -1, axis=0)
return pt.set_subtensor(chi_l[j_last], diff_l)

- def no_op(chi_lm1, diff_l) -> TensorVariable:
- return chi_lm1
-
- def scan_body(update_mask_l, diff_l, chi_lm1) -> TensorVariable:
- return pt.switch(update_mask_l, chi_update(chi_lm1, diff_l), no_op(chi_lm1, diff_l))

chi_init = pt.zeros((J, N))
chi_mat, _ = pytensor.scan(
- fn=scan_body,
+ fn=chi_update,
outputs_info=chi_init,
- sequences=[
- update_mask,
- diff,
- ],
+ sequences=[diff],
allow_gc=False,
)
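The retained `chi_update` step is a rolling window over the last J differences; a NumPy sketch of the same mechanics on toy data:

```python
import numpy as np

J, N = 3, 2
chi = np.zeros((J, N))
diffs = np.arange(1, 5, dtype=float)[:, None] * np.ones((4, N))  # rows 1..4

for diff_l in diffs:
    chi = np.roll(chi, -1, axis=0)  # shift history up by one slot
    chi[J - 1] = diff_l             # newest difference enters at the end

print(chi)  # rows now hold diffs 2, 3, 4: the last J updates in order
```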

@@ -382,19 +349,15 @@ def scan_body(update_mask_l, diff_l, chi_lm1) -> TensorVariable:
# (L, N, J)
return chi_mat

- def get_chi_matrix_2(
- diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
- ) -> TensorVariable:
+ def get_chi_matrix_2(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
L, N = diff.shape

- diff_masked = update_mask[:, None] * diff

# diff_padded: (L+J, N)
pad_width = pt.zeros(shape=(2, 2), dtype="int32")
- pad_width = pt.set_subtensor(pad_width[0, 0], J)
- diff_padded = pt.pad(diff_masked, pad_width, mode="constant")
+ pad_width = pt.set_subtensor(pad_width[0, 0], J - 1)
+ diff_padded = pt.pad(diff, pad_width, mode="constant")

- index = pt.arange(L)[:, None] + pt.arange(J)[None, :]
+ index = pt.arange(L)[..., None] + pt.arange(J)[None, ...]
index = index.reshape((L, J))

chi_mat = pt.matrix_transpose(diff_padded[index])
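The pad-and-gather above builds the same last-J window without a scan; a NumPy sketch of the indexing, using the `J - 1` padding from this diff:

```python
import numpy as np

L, N, J = 5, 2, 3
diff = np.arange(L * N, dtype=float).reshape(L, N)

# J - 1 zero rows on top stand in for the missing history at early iterations
diff_padded = np.pad(diff, ((J - 1, 0), (0, 0)))

index = np.arange(L)[:, None] + np.arange(J)[None, :]  # (L, J) window indices
chi_mat = np.swapaxes(diff_padded[index], 1, 2)        # (L, N, J)

# the last slot of each window is the current difference
assert np.allclose(chi_mat[:, :, -1], diff)
```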
@@ -403,8 +366,10 @@ def get_chi_matrix_2(
return chi_mat

L, N = alpha.shape
- S = get_chi_matrix_1(s, update_mask, J)
- Z = get_chi_matrix_1(z, update_mask, J)
+
+ # changed to get_chi_matrix_2 after removing update_mask
+ S = get_chi_matrix_2(s, J)
+ Z = get_chi_matrix_2(z, J)

# E: (L, J, J)
Ij = pt.eye(J)[None, ...]
@@ -785,7 +750,6 @@ def make_pathfinder_body(
num_draws: int,
maxcor: int,
num_elbo_draws: int,
- epsilon: float,
**compile_kwargs: dict,
) -> Function:
"""
@@ -801,8 +765,6 @@
The maximum number of iterations for the L-BFGS algorithm.
num_elbo_draws : int
The number of draws for the Evidence Lower Bound (ELBO) estimation.
- epsilon : float
- The value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L.
compile_kwargs : dict
Additional keyword arguments for the PyTensor compiler.

@@ -827,11 +789,10 @@

num_draws = pt.constant(num_draws, "num_draws", dtype="int32")
num_elbo_draws = pt.constant(num_elbo_draws, "num_elbo_draws", dtype="int32")
- epsilon = pt.constant(epsilon, "epsilon", dtype="float64")
maxcor = pt.constant(maxcor, "maxcor", dtype="int32")

- alpha, s, z, update_mask = alpha_recover(x_full, g_full, epsilon=epsilon)
- beta, gamma = inverse_hessian_factors(alpha, s, z, update_mask, J=maxcor)
+ alpha, s, z = alpha_recover(x_full, g_full)
+ beta, gamma = inverse_hessian_factors(alpha, s, z, J=maxcor)

# ignore initial point - x, g: (L, N)
x = x_full[1:]
@@ -941,11 +902,11 @@ def neg_logp_dlogp_func(x):
x_base = DictToArrayBijection.map(ip).data

# lbfgs
- lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls)
+ lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls, epsilon)

# pathfinder body
pathfinder_body_fn = make_pathfinder_body(
- logp_func, num_draws, maxcor, num_elbo_draws, epsilon, **compile_kwargs
+ logp_func, num_draws, maxcor, num_elbo_draws, **compile_kwargs
)
rngs = find_rng_nodes(pathfinder_body_fn.maker.fgraph.outputs)

@@ -957,7 +918,7 @@ def single_pathfinder_fn(random_seed: int) -> PathfinderResult:
x0 = x_base + jitter_value
x, g, lbfgs_niter, lbfgs_status = lbfgs.minimize(x0)

- if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK}:
+ if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT}:
raise LBFGSInitFailed(lbfgs_status)
elif lbfgs_status == LBFGSStatus.LBFGS_FAILED:
raise LBFGSException()
@@ -1399,8 +1360,8 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
LBFGSStatus.MAX_ITER_REACHED: "MAX_ITER_REACHED: LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.",
LBFGSStatus.INIT_FAILED: "INIT_FAILED: LBFGS failed to initialize. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.",
LBFGSStatus.NON_FINITE: "NON_FINITE: LBFGS objective function produced inf or nan at the last iteration. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
- LBFGSStatus.LOW_UPDATE_MASK_RATIO: "LOW_UPDATE_MASK_RATIO: Majority of LBFGS iterations were not accepted due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
- LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK: "INIT_FAILED_LOW_UPDATE_MASK: LBFGS failed to initialize due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
+ LBFGSStatus.LOW_UPDATE_PCT: "LOW_UPDATE_PCT: Majority of LBFGS iterations were not accepted due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
+ LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT: "INIT_FAILED_LOW_UPDATE_PCT: LBFGS failed to initialize due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
}

path_status_message = {
17 changes: 8 additions & 9 deletions tests/test_pathfinder.py
@@ -106,8 +106,8 @@ def test_unstable_lbfgs_update_mask(capsys, jitter):
)
out, err = capsys.readouterr()
status_pattern = [
r"INIT_FAILED_LOW_UPDATE_MASK\s+\d+",
r"LOW_UPDATE_MASK_RATIO\s+\d+",
r"INIT_FAILED_LOW_UPDATE_PCT\s+\d+",
r"LOW_UPDATE_PCT\s+\d+",
r"LBFGS_FAILED\s+\d+",
r"SUCCESS\s+\d+",
]
@@ -126,8 +126,8 @@
out, err = capsys.readouterr()

status_pattern = [
r"INIT_FAILED_LOW_UPDATE_MASK\s+2",
r"LOW_UPDATE_MASK_RATIO\s+2",
r"INIT_FAILED_LOW_UPDATE_PCT\s+2",
r"LOW_UPDATE_PCT\s+2",
r"LBFGS_FAILED\s+4",
]
for pattern in status_pattern:
@@ -232,12 +232,11 @@ def test_bfgs_sample():
# get factors
x_full = pt.as_tensor(x_data, dtype="float64")
g_full = pt.as_tensor(g_data, dtype="float64")
- epsilon = 1e-11

x = x_full[1:]
g = g_full[1:]
- alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon)
- beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J)
+ alpha, s, z = alpha_recover(x_full, g_full)
+ beta, gamma = inverse_hessian_factors(alpha, s, z, J)

# sample
phi, logq = bfgs_sample(
@@ -252,8 +251,8 @@
# check shapes
assert beta.eval().shape == (L, N, 2 * J)
assert gamma.eval().shape == (L, 2 * J, 2 * J)
- assert phi.eval().shape == (L, num_samples, N)
- assert logq.eval().shape == (L, num_samples)
+ assert all(phi.shape.eval() == (L, num_samples, N))
+ assert all(logq.shape.eval() == (L, num_samples))
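Aside on the switched assertions: `X.shape.eval()` evaluates only the symbolic shape graph and returns an integer array, hence the `all(...)` wrapper, whereas `X.eval().shape` would materialize the full tensor first. A tiny self-contained check:

```python
import pytensor.tensor as pt

x = pt.zeros((3, 4))
assert all(x.shape.eval() == (3, 4))  # compares np.array([3, 4]) elementwise
```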


@pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])