Skip to content

Commit 1cc9962

Browse files
committed
Refactor alpha_recover and inverse_hessian_factors to remove update_mask parameter
- Removed the update_mask variable from alpha_recover and inverse_hessian_factors functions. - Simplified the logic in alpha_recover by directly computing alpha without filtering updates. - Changes should offer speed-ups by reducing reliance on scan functions, and perform vectorised operations.
1 parent f78d5b6 commit 1cc9962

File tree

1 file changed

+11
-40
lines changed

1 file changed

+11
-40
lines changed

pymc_extras/inference/pathfinder/pathfinder.py

+11-40
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,6 @@ def alpha_recover(
259259
position differences, shape (L, N)
260260
z : TensorVariable
261261
gradient differences, shape (L, N)
262-
update_mask : TensorVariable
263-
mask for filtering updates, shape (L,)
264262
265263
Notes
266264
-----
@@ -281,43 +279,28 @@ def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
281279
) # fmt:off
282280
return 1.0 / inv_alpha_l
283281

284-
def return_alpha_lm1(alpha_lm1, s_l, z_l) -> TensorVariable:
285-
return alpha_lm1[-1]
286-
287-
def scan_body(update_mask_l, s_l, z_l, alpha_lm1) -> TensorVariable:
288-
return pt.switch(
289-
update_mask_l,
290-
compute_alpha_l(alpha_lm1, s_l, z_l),
291-
return_alpha_lm1(alpha_lm1, s_l, z_l),
292-
)
293-
294282
Lp1, N = x.shape
295283
s = pt.diff(x, axis=0)
296284
z = pt.diff(g, axis=0)
297285
alpha_l_init = pt.ones(N)
298-
sz = (s * z).sum(axis=-1)
299-
# update_mask = sz > epsilon * pt.linalg.norm(z, axis=-1)
300-
# pt.linalg.norm does not work with JAX!!
301-
update_mask = sz > epsilon * pt.sqrt(pt.sum(z**2, axis=-1))
302286

303287
alpha, _ = pytensor.scan(
304-
fn=scan_body,
288+
fn=compute_alpha_l,
305289
outputs_info=alpha_l_init,
306-
sequences=[update_mask, s, z],
290+
sequences=[s, z],
307291
n_steps=Lp1 - 1,
308292
allow_gc=False,
309293
)
310294

311295
# assert np.all(alpha.eval() > 0), "alpha cannot be negative"
312296
# alpha: (L, N), update_mask: (L, N)
313-
return alpha, s, z, update_mask
297+
return alpha, s, z
314298

315299

316300
def inverse_hessian_factors(
317301
alpha: TensorVariable,
318302
s: TensorVariable,
319303
z: TensorVariable,
320-
update_mask: TensorVariable,
321304
J: TensorConstant,
322305
) -> tuple[TensorVariable, TensorVariable]:
323306
"""compute the inverse hessian factors for the BFGS approximation.
@@ -330,8 +313,6 @@ def inverse_hessian_factors(
330313
position differences, shape (L, N)
331314
z : TensorVariable
332315
gradient differences, shape (L, N)
333-
update_mask : TensorVariable
334-
mask for filtering updates, shape (L,)
335316
J : TensorConstant
336317
history size for L-BFGS
337318
@@ -350,30 +331,20 @@ def inverse_hessian_factors(
350331
# NOTE: get_chi_matrix_1 is a modified version of get_chi_matrix_2 to closely follow Zhang et al., (2022)
351332
# NOTE: get_chi_matrix_2 is from blackjax which MAYBE incorrectly implemented
352333

353-
def get_chi_matrix_1(
354-
diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
355-
) -> TensorVariable:
334+
def get_chi_matrix_1(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
335+
# TODO: vectorize this!
356336
L, N = diff.shape
357337
j_last = pt.as_tensor(J - 1) # since indexing starts at 0
358338

359339
def chi_update(chi_lm1, diff_l) -> TensorVariable:
360340
chi_l = pt.roll(chi_lm1, -1, axis=0)
361341
return pt.set_subtensor(chi_l[j_last], diff_l)
362342

363-
def no_op(chi_lm1, diff_l) -> TensorVariable:
364-
return chi_lm1
365-
366-
def scan_body(update_mask_l, diff_l, chi_lm1) -> TensorVariable:
367-
return pt.switch(update_mask_l, chi_update(chi_lm1, diff_l), no_op(chi_lm1, diff_l))
368-
369343
chi_init = pt.zeros((J, N))
370344
chi_mat, _ = pytensor.scan(
371-
fn=scan_body,
345+
fn=chi_update,
372346
outputs_info=chi_init,
373-
sequences=[
374-
update_mask,
375-
diff,
376-
],
347+
sequences=[diff],
377348
allow_gc=False,
378349
)
379350

@@ -403,8 +374,8 @@ def get_chi_matrix_2(
403374
return chi_mat
404375

405376
L, N = alpha.shape
406-
S = get_chi_matrix_1(s, update_mask, J)
407-
Z = get_chi_matrix_1(z, update_mask, J)
377+
S = get_chi_matrix_1(s, J)
378+
Z = get_chi_matrix_1(z, J)
408379

409380
# E: (L, J, J)
410381
Ij = pt.eye(J)[None, ...]
@@ -830,8 +801,8 @@ def make_pathfinder_body(
830801
epsilon = pt.constant(epsilon, "epsilon", dtype="float64")
831802
maxcor = pt.constant(maxcor, "maxcor", dtype="int32")
832803

833-
alpha, s, z, update_mask = alpha_recover(x_full, g_full, epsilon=epsilon)
834-
beta, gamma = inverse_hessian_factors(alpha, s, z, update_mask, J=maxcor)
804+
alpha, s, z = alpha_recover(x_full, g_full, epsilon=epsilon)
805+
beta, gamma = inverse_hessian_factors(alpha, s, z, J=maxcor)
835806

836807
# ignore initial point - x, g: (L, N)
837808
x = x_full[1:]

0 commit comments

Comments
 (0)