@@ -259,8 +259,6 @@ def alpha_recover(
259
259
position differences, shape (L, N)
260
260
z : TensorVariable
261
261
gradient differences, shape (L, N)
262
- update_mask : TensorVariable
263
- mask for filtering updates, shape (L,)
264
262
265
263
Notes
266
264
-----
@@ -281,43 +279,28 @@ def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
281
279
) # fmt:off
282
280
return 1.0 / inv_alpha_l
283
281
284
- def return_alpha_lm1 (alpha_lm1 , s_l , z_l ) -> TensorVariable :
285
- return alpha_lm1 [- 1 ]
286
-
287
- def scan_body (update_mask_l , s_l , z_l , alpha_lm1 ) -> TensorVariable :
288
- return pt .switch (
289
- update_mask_l ,
290
- compute_alpha_l (alpha_lm1 , s_l , z_l ),
291
- return_alpha_lm1 (alpha_lm1 , s_l , z_l ),
292
- )
293
-
294
282
Lp1 , N = x .shape
295
283
s = pt .diff (x , axis = 0 )
296
284
z = pt .diff (g , axis = 0 )
297
285
alpha_l_init = pt .ones (N )
298
- sz = (s * z ).sum (axis = - 1 )
299
- # update_mask = sz > epsilon * pt.linalg.norm(z, axis=-1)
300
- # pt.linalg.norm does not work with JAX!!
301
- update_mask = sz > epsilon * pt .sqrt (pt .sum (z ** 2 , axis = - 1 ))
302
286
303
287
alpha , _ = pytensor .scan (
304
- fn = scan_body ,
288
+ fn = compute_alpha_l ,
305
289
outputs_info = alpha_l_init ,
306
- sequences = [update_mask , s , z ],
290
+ sequences = [s , z ],
307
291
n_steps = Lp1 - 1 ,
308
292
allow_gc = False ,
309
293
)
310
294
311
295
# assert np.all(alpha.eval() > 0), "alpha cannot be negative"
312
296
# alpha: (L, N), update_mask: (L, N)
313
- return alpha , s , z , update_mask
297
+ return alpha , s , z
314
298
315
299
316
300
def inverse_hessian_factors (
317
301
alpha : TensorVariable ,
318
302
s : TensorVariable ,
319
303
z : TensorVariable ,
320
- update_mask : TensorVariable ,
321
304
J : TensorConstant ,
322
305
) -> tuple [TensorVariable , TensorVariable ]:
323
306
"""compute the inverse hessian factors for the BFGS approximation.
@@ -330,8 +313,6 @@ def inverse_hessian_factors(
330
313
position differences, shape (L, N)
331
314
z : TensorVariable
332
315
gradient differences, shape (L, N)
333
- update_mask : TensorVariable
334
- mask for filtering updates, shape (L,)
335
316
J : TensorConstant
336
317
history size for L-BFGS
337
318
@@ -350,30 +331,20 @@ def inverse_hessian_factors(
350
331
# NOTE: get_chi_matrix_1 is a modified version of get_chi_matrix_2 to closely follow Zhang et al., (2022)
351
332
# NOTE: get_chi_matrix_2 is from blackjax which MAYBE incorrectly implemented
352
333
353
- def get_chi_matrix_1 (
354
- diff : TensorVariable , update_mask : TensorVariable , J : TensorConstant
355
- ) -> TensorVariable :
334
+ def get_chi_matrix_1 (diff : TensorVariable , J : TensorConstant ) -> TensorVariable :
335
+ # TODO: vectorize this!
356
336
L , N = diff .shape
357
337
j_last = pt .as_tensor (J - 1 ) # since indexing starts at 0
358
338
359
339
def chi_update (chi_lm1 , diff_l ) -> TensorVariable :
360
340
chi_l = pt .roll (chi_lm1 , - 1 , axis = 0 )
361
341
return pt .set_subtensor (chi_l [j_last ], diff_l )
362
342
363
- def no_op (chi_lm1 , diff_l ) -> TensorVariable :
364
- return chi_lm1
365
-
366
- def scan_body (update_mask_l , diff_l , chi_lm1 ) -> TensorVariable :
367
- return pt .switch (update_mask_l , chi_update (chi_lm1 , diff_l ), no_op (chi_lm1 , diff_l ))
368
-
369
343
chi_init = pt .zeros ((J , N ))
370
344
chi_mat , _ = pytensor .scan (
371
- fn = scan_body ,
345
+ fn = chi_update ,
372
346
outputs_info = chi_init ,
373
- sequences = [
374
- update_mask ,
375
- diff ,
376
- ],
347
+ sequences = [diff ],
377
348
allow_gc = False ,
378
349
)
379
350
@@ -403,8 +374,8 @@ def get_chi_matrix_2(
403
374
return chi_mat
404
375
405
376
L , N = alpha .shape
406
- S = get_chi_matrix_1 (s , update_mask , J )
407
- Z = get_chi_matrix_1 (z , update_mask , J )
377
+ S = get_chi_matrix_1 (s , J )
378
+ Z = get_chi_matrix_1 (z , J )
408
379
409
380
# E: (L, J, J)
410
381
Ij = pt .eye (J )[None , ...]
@@ -830,8 +801,8 @@ def make_pathfinder_body(
830
801
epsilon = pt .constant (epsilon , "epsilon" , dtype = "float64" )
831
802
maxcor = pt .constant (maxcor , "maxcor" , dtype = "int32" )
832
803
833
- alpha , s , z , update_mask = alpha_recover (x_full , g_full , epsilon = epsilon )
834
- beta , gamma = inverse_hessian_factors (alpha , s , z , update_mask , J = maxcor )
804
+ alpha , s , z = alpha_recover (x_full , g_full , epsilon = epsilon )
805
+ beta , gamma = inverse_hessian_factors (alpha , s , z , J = maxcor )
835
806
836
807
# ignore initial point - x, g: (L, N)
837
808
x = x_full [1 :]
0 commit comments