Skip to content

Commit 4431749

Browse files
authored
Refactor alpha_recover and inverse_hessian_factors to remove update_mask parameter (#462)
* Refactor alpha_recover and inverse_hessian_factors to remove update_mask parameter
  - Removed the update_mask variable from the alpha_recover and inverse_hessian_factors functions.
  - Simplified the logic in alpha_recover by computing alpha directly, without filtering updates.
  - These changes should offer speed-ups by reducing reliance on scan functions and performing vectorised operations.
* WIP: tidying up and shortening variable names
* WIP: modified get_chi_matrix
* Updated LBFGS status handling and alpha_recover function
  - Corrected the condition for LOW_UPDATE_PCT in LBFGS status handling.
  - Removed update_mask references in alpha_recover and inverse_hessian_factors.
  - Adjusted test cases to reflect changes in status messages and function signatures.
1 parent f08a40f commit 4431749

File tree

3 files changed

+50
-87
lines changed

3 files changed

+50
-87
lines changed

pymc_extras/inference/pathfinder/lbfgs.py

+16-13
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,14 @@ class LBFGSHistoryManager:
3737
initial position
3838
maxiter : int
3939
maximum number of iterations to store
40+
epsilon : float
41+
tolerance for lbfgs update
4042
"""
4143

4244
value_grad_fn: Callable[[NDArray[np.float64]], tuple[np.float64, NDArray[np.float64]]]
4345
x0: NDArray[np.float64]
4446
maxiter: int
47+
epsilon: float
4548
x_history: NDArray[np.float64] = field(init=False)
4649
g_history: NDArray[np.float64] = field(init=False)
4750
count: int = field(init=False)
@@ -85,10 +88,9 @@ def entry_condition_met(self, x, value, grad) -> bool:
8588
s = x - self.x_history[self.count - 1]
8689
z = grad - self.g_history[self.count - 1]
8790
sz = (s * z).sum(axis=-1)
88-
epsilon = 1e-8
89-
update_mask = sz > epsilon * np.sqrt(np.sum(z**2, axis=-1))
91+
update = sz > self.epsilon * np.sqrt(np.sum(z**2, axis=-1))
9092

91-
if update_mask:
93+
if update:
9294
return True
9395
else:
9496
return False
@@ -105,10 +107,10 @@ class LBFGSStatus(Enum):
105107
CONVERGED = auto()
106108
MAX_ITER_REACHED = auto()
107109
NON_FINITE = auto()
108-
LOW_UPDATE_MASK_RATIO = auto()
110+
LOW_UPDATE_PCT = auto()
109111
# Statuses that lead to Exceptions:
110112
INIT_FAILED = auto()
111-
INIT_FAILED_LOW_UPDATE_MASK = auto()
113+
INIT_FAILED_LOW_UPDATE_PCT = auto()
112114
LBFGS_FAILED = auto()
113115

114116

@@ -144,17 +146,20 @@ class LBFGS:
144146
gradient tolerance for convergence, defaults to 1e-8
145147
maxls : int, optional
146148
maximum number of line search steps, defaults to 1000
149+
epsilon : float, optional
150+
tolerance for lbfgs update, defaults to 1e-8
147151
"""
148152

149153
def __init__(
150-
self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000
154+
self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000, epsilon=1e-8
151155
) -> None:
152156
self.value_grad_fn = value_grad_fn
153157
self.maxcor = maxcor
154158
self.maxiter = maxiter
155159
self.ftol = ftol
156160
self.gtol = gtol
157161
self.maxls = maxls
162+
self.epsilon = epsilon
158163

159164
def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
160165
"""minimizes objective function starting from initial position.
@@ -179,7 +184,7 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
179184
x0 = np.array(x0, dtype=np.float64)
180185

181186
history_manager = LBFGSHistoryManager(
182-
value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter
187+
value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter, epsilon=self.epsilon
183188
)
184189

185190
result = minimize(
@@ -199,24 +204,22 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
199204
history = history_manager.get_history()
200205

201206
# warnings and suggestions for LBFGSStatus are displayed at the end
202-
# threshold determining if the number of update mask is low compared to iterations
207+
# threshold determining if the number of lbfgs updates is low compared to iterations
203208
low_update_threshold = 3
204209

205-
logging.warning(f"LBFGS status: {result} \n nit: {result.nit} \n count: {history.count}")
206-
207210
if history.count <= 1: # triggers LBFGSInitFailed
208211
if result.nit < low_update_threshold:
209212
lbfgs_status = LBFGSStatus.INIT_FAILED
210213
else:
211-
lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK
214+
lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT
212215
elif result.status == 1:
213216
# (result.nit > maxiter) or (result.nit > maxls)
214217
lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
215218
elif result.status == 2:
216219
# precision loss resulting in inf or nan
217220
lbfgs_status = LBFGSStatus.NON_FINITE
218-
elif history.count < low_update_threshold * result.nit:
219-
lbfgs_status = LBFGSStatus.LOW_UPDATE_MASK_RATIO
221+
elif history.count * low_update_threshold < result.nit:
222+
lbfgs_status = LBFGSStatus.LOW_UPDATE_PCT
220223
else:
221224
lbfgs_status = LBFGSStatus.CONVERGED
222225

pymc_extras/inference/pathfinder/pathfinder.py

+26-65
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,8 @@ def convert_flat_trace_to_idata(
237237

238238

239239
def alpha_recover(
240-
x: TensorVariable, g: TensorVariable, epsilon: TensorVariable
241-
) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
240+
x: TensorVariable, g: TensorVariable
241+
) -> tuple[TensorVariable, TensorVariable, TensorVariable]:
242242
"""compute the diagonal elements of the inverse Hessian at each iterations of L-BFGS and filter updates.
243243
244244
Parameters
@@ -247,9 +247,6 @@ def alpha_recover(
247247
position array, shape (L+1, N)
248248
g : TensorVariable
249249
gradient array, shape (L+1, N)
250-
epsilon : float
251-
threshold for filtering updates based on inner product of position
252-
and gradient differences
253250
254251
Returns
255252
-------
@@ -259,15 +256,13 @@ def alpha_recover(
259256
position differences, shape (L, N)
260257
z : TensorVariable
261258
gradient differences, shape (L, N)
262-
update_mask : TensorVariable
263-
mask for filtering updates, shape (L,)
264259
265260
Notes
266261
-----
267262
shapes: L=batch_size, N=num_params
268263
"""
269264

270-
def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
265+
def compute_alpha_l(s_l, z_l, alpha_lm1) -> TensorVariable:
271266
# alpha_lm1: (N,)
272267
# s_l: (N,)
273268
# z_l: (N,)
@@ -281,43 +276,28 @@ def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
281276
) # fmt:off
282277
return 1.0 / inv_alpha_l
283278

284-
def return_alpha_lm1(alpha_lm1, s_l, z_l) -> TensorVariable:
285-
return alpha_lm1[-1]
286-
287-
def scan_body(update_mask_l, s_l, z_l, alpha_lm1) -> TensorVariable:
288-
return pt.switch(
289-
update_mask_l,
290-
compute_alpha_l(alpha_lm1, s_l, z_l),
291-
return_alpha_lm1(alpha_lm1, s_l, z_l),
292-
)
293-
294279
Lp1, N = x.shape
295280
s = pt.diff(x, axis=0)
296281
z = pt.diff(g, axis=0)
297282
alpha_l_init = pt.ones(N)
298-
sz = (s * z).sum(axis=-1)
299-
# update_mask = sz > epsilon * pt.linalg.norm(z, axis=-1)
300-
# pt.linalg.norm does not work with JAX!!
301-
update_mask = sz > epsilon * pt.sqrt(pt.sum(z**2, axis=-1))
302283

303284
alpha, _ = pytensor.scan(
304-
fn=scan_body,
285+
fn=compute_alpha_l,
305286
outputs_info=alpha_l_init,
306-
sequences=[update_mask, s, z],
287+
sequences=[s, z],
307288
n_steps=Lp1 - 1,
308289
allow_gc=False,
309290
)
310291

311292
# assert np.all(alpha.eval() > 0), "alpha cannot be negative"
312-
# alpha: (L, N), update_mask: (L, N)
313-
return alpha, s, z, update_mask
293+
# alpha: (L, N)
294+
return alpha, s, z
314295

315296

316297
def inverse_hessian_factors(
317298
alpha: TensorVariable,
318299
s: TensorVariable,
319300
z: TensorVariable,
320-
update_mask: TensorVariable,
321301
J: TensorConstant,
322302
) -> tuple[TensorVariable, TensorVariable]:
323303
"""compute the inverse hessian factors for the BFGS approximation.
@@ -330,8 +310,6 @@ def inverse_hessian_factors(
330310
position differences, shape (L, N)
331311
z : TensorVariable
332312
gradient differences, shape (L, N)
333-
update_mask : TensorVariable
334-
mask for filtering updates, shape (L,)
335313
J : TensorConstant
336314
history size for L-BFGS
337315
@@ -350,30 +328,19 @@ def inverse_hessian_factors(
350328
# NOTE: get_chi_matrix_1 is a modified version of get_chi_matrix_2 to closely follow Zhang et al., (2022)
351329
# NOTE: get_chi_matrix_2 is from blackjax, which may be incorrectly implemented
352330

353-
def get_chi_matrix_1(
354-
diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
355-
) -> TensorVariable:
331+
def get_chi_matrix_1(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
356332
L, N = diff.shape
357333
j_last = pt.as_tensor(J - 1) # since indexing starts at 0
358334

359-
def chi_update(chi_lm1, diff_l) -> TensorVariable:
335+
def chi_update(diff_l, chi_lm1) -> TensorVariable:
360336
chi_l = pt.roll(chi_lm1, -1, axis=0)
361337
return pt.set_subtensor(chi_l[j_last], diff_l)
362338

363-
def no_op(chi_lm1, diff_l) -> TensorVariable:
364-
return chi_lm1
365-
366-
def scan_body(update_mask_l, diff_l, chi_lm1) -> TensorVariable:
367-
return pt.switch(update_mask_l, chi_update(chi_lm1, diff_l), no_op(chi_lm1, diff_l))
368-
369339
chi_init = pt.zeros((J, N))
370340
chi_mat, _ = pytensor.scan(
371-
fn=scan_body,
341+
fn=chi_update,
372342
outputs_info=chi_init,
373-
sequences=[
374-
update_mask,
375-
diff,
376-
],
343+
sequences=[diff],
377344
allow_gc=False,
378345
)
379346

@@ -382,19 +349,15 @@ def scan_body(update_mask_l, diff_l, chi_lm1) -> TensorVariable:
382349
# (L, N, J)
383350
return chi_mat
384351

385-
def get_chi_matrix_2(
386-
diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
387-
) -> TensorVariable:
352+
def get_chi_matrix_2(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
388353
L, N = diff.shape
389354

390-
diff_masked = update_mask[:, None] * diff
391-
392355
# diff_padded: (L+J, N)
393356
pad_width = pt.zeros(shape=(2, 2), dtype="int32")
394-
pad_width = pt.set_subtensor(pad_width[0, 0], J)
395-
diff_padded = pt.pad(diff_masked, pad_width, mode="constant")
357+
pad_width = pt.set_subtensor(pad_width[0, 0], J - 1)
358+
diff_padded = pt.pad(diff, pad_width, mode="constant")
396359

397-
index = pt.arange(L)[:, None] + pt.arange(J)[None, :]
360+
index = pt.arange(L)[..., None] + pt.arange(J)[None, ...]
398361
index = index.reshape((L, J))
399362

400363
chi_mat = pt.matrix_transpose(diff_padded[index])
@@ -403,8 +366,10 @@ def get_chi_matrix_2(
403366
return chi_mat
404367

405368
L, N = alpha.shape
406-
S = get_chi_matrix_1(s, update_mask, J)
407-
Z = get_chi_matrix_1(z, update_mask, J)
369+
370+
# changed to get_chi_matrix_2 after removing update_mask
371+
S = get_chi_matrix_2(s, J)
372+
Z = get_chi_matrix_2(z, J)
408373

409374
# E: (L, J, J)
410375
Ij = pt.eye(J)[None, ...]
@@ -785,7 +750,6 @@ def make_pathfinder_body(
785750
num_draws: int,
786751
maxcor: int,
787752
num_elbo_draws: int,
788-
epsilon: float,
789753
**compile_kwargs: dict,
790754
) -> Function:
791755
"""
@@ -801,8 +765,6 @@ def make_pathfinder_body(
801765
The maximum number of iterations for the L-BFGS algorithm.
802766
num_elbo_draws : int
803767
The number of draws for the Evidence Lower Bound (ELBO) estimation.
804-
epsilon : float
805-
The value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L.
806768
compile_kwargs : dict
807769
Additional keyword arguments for the PyTensor compiler.
808770
@@ -827,11 +789,10 @@ def make_pathfinder_body(
827789

828790
num_draws = pt.constant(num_draws, "num_draws", dtype="int32")
829791
num_elbo_draws = pt.constant(num_elbo_draws, "num_elbo_draws", dtype="int32")
830-
epsilon = pt.constant(epsilon, "epsilon", dtype="float64")
831792
maxcor = pt.constant(maxcor, "maxcor", dtype="int32")
832793

833-
alpha, s, z, update_mask = alpha_recover(x_full, g_full, epsilon=epsilon)
834-
beta, gamma = inverse_hessian_factors(alpha, s, z, update_mask, J=maxcor)
794+
alpha, s, z = alpha_recover(x_full, g_full)
795+
beta, gamma = inverse_hessian_factors(alpha, s, z, J=maxcor)
835796

836797
# ignore initial point - x, g: (L, N)
837798
x = x_full[1:]
@@ -941,11 +902,11 @@ def neg_logp_dlogp_func(x):
941902
x_base = DictToArrayBijection.map(ip).data
942903

943904
# lbfgs
944-
lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls)
905+
lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls, epsilon)
945906

946907
# pathfinder body
947908
pathfinder_body_fn = make_pathfinder_body(
948-
logp_func, num_draws, maxcor, num_elbo_draws, epsilon, **compile_kwargs
909+
logp_func, num_draws, maxcor, num_elbo_draws, **compile_kwargs
949910
)
950911
rngs = find_rng_nodes(pathfinder_body_fn.maker.fgraph.outputs)
951912

@@ -957,7 +918,7 @@ def single_pathfinder_fn(random_seed: int) -> PathfinderResult:
957918
x0 = x_base + jitter_value
958919
x, g, lbfgs_niter, lbfgs_status = lbfgs.minimize(x0)
959920

960-
if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK}:
921+
if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT}:
961922
raise LBFGSInitFailed(lbfgs_status)
962923
elif lbfgs_status == LBFGSStatus.LBFGS_FAILED:
963924
raise LBFGSException()
@@ -1399,8 +1360,8 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
13991360
LBFGSStatus.MAX_ITER_REACHED: "MAX_ITER_REACHED: LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.",
14001361
LBFGSStatus.INIT_FAILED: "INIT_FAILED: LBFGS failed to initialize. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.",
14011362
LBFGSStatus.NON_FINITE: "NON_FINITE: LBFGS objective function produced inf or nan at the last iteration. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
1402-
LBFGSStatus.LOW_UPDATE_MASK_RATIO: "LOW_UPDATE_MASK_RATIO: Majority of LBFGS iterations were not accepted due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
1403-
LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK: "INIT_FAILED_LOW_UPDATE_MASK: LBFGS failed to initialize due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
1363+
LBFGSStatus.LOW_UPDATE_PCT: "LOW_UPDATE_PCT: Majority of LBFGS iterations were not accepted due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
1364+
LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT: "INIT_FAILED_LOW_UPDATE_PCT: LBFGS failed to initialize due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
14041365
}
14051366

14061367
path_status_message = {

tests/test_pathfinder.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def test_unstable_lbfgs_update_mask(capsys, jitter):
106106
)
107107
out, err = capsys.readouterr()
108108
status_pattern = [
109-
r"INIT_FAILED_LOW_UPDATE_MASK\s+\d+",
110-
r"LOW_UPDATE_MASK_RATIO\s+\d+",
109+
r"INIT_FAILED_LOW_UPDATE_PCT\s+\d+",
110+
r"LOW_UPDATE_PCT\s+\d+",
111111
r"LBFGS_FAILED\s+\d+",
112112
r"SUCCESS\s+\d+",
113113
]
@@ -126,8 +126,8 @@ def test_unstable_lbfgs_update_mask(capsys, jitter):
126126
out, err = capsys.readouterr()
127127

128128
status_pattern = [
129-
r"INIT_FAILED_LOW_UPDATE_MASK\s+2",
130-
r"LOW_UPDATE_MASK_RATIO\s+2",
129+
r"INIT_FAILED_LOW_UPDATE_PCT\s+2",
130+
r"LOW_UPDATE_PCT\s+2",
131131
r"LBFGS_FAILED\s+4",
132132
]
133133
for pattern in status_pattern:
@@ -232,12 +232,11 @@ def test_bfgs_sample():
232232
# get factors
233233
x_full = pt.as_tensor(x_data, dtype="float64")
234234
g_full = pt.as_tensor(g_data, dtype="float64")
235-
epsilon = 1e-11
236235

237236
x = x_full[1:]
238237
g = g_full[1:]
239-
alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon)
240-
beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J)
238+
alpha, s, z = alpha_recover(x_full, g_full)
239+
beta, gamma = inverse_hessian_factors(alpha, s, z, J)
241240

242241
# sample
243242
phi, logq = bfgs_sample(
@@ -252,8 +251,8 @@ def test_bfgs_sample():
252251
# check shapes
253252
assert beta.eval().shape == (L, N, 2 * J)
254253
assert gamma.eval().shape == (L, 2 * J, 2 * J)
255-
assert phi.eval().shape == (L, num_samples, N)
256-
assert logq.eval().shape == (L, num_samples)
254+
assert all(phi.shape.eval() == (L, num_samples, N))
255+
assert all(logq.shape.eval() == (L, num_samples))
257256

258257

259258
@pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])

0 commit comments

Comments
 (0)