61
61
62
62
from pytensor .graph .basic import Variable
63
63
from pytensor .graph .replace import graph_replace
64
+ from pytensor .scalar .basic import identity as scalar_identity
65
+ from pytensor .tensor .elemwise import Elemwise
64
66
from pytensor .tensor .shape import unbroadcast
65
67
66
68
import pymc as pm
74
76
SeedSequenceSeed ,
75
77
compile ,
76
78
find_rng_nodes ,
77
- identity ,
78
79
reseed_rngs ,
79
80
)
80
81
from pymc .util import (
@@ -332,6 +333,7 @@ def step_function(
332
333
more_replacements = None ,
333
334
total_grad_norm_constraint = None ,
334
335
score = False ,
336
+ compile_kwargs = None ,
335
337
fn_kwargs = None ,
336
338
):
337
339
R"""Step function that should be called on each optimization step.
@@ -362,17 +364,30 @@ def step_function(
362
364
Bounds gradient norm, prevents exploding gradient problem
363
365
score: `bool`
364
366
calculate loss on each step? Defaults to False for speed
365
- fn_kwargs : `dict`
367
+ compile_kwargs : `dict`
366
368
Add kwargs to pytensor.function (e.g. `{'profile': True}`)
369
+ fn_kwargs: dict
370
+ arbitrary kwargs passed to `pytensor.function`
371
+
372
+ .. warning:: `fn_kwargs` is deprecated and will be removed in future versions
373
+
367
374
more_replacements: `dict`
368
375
Apply custom replacements before calculating gradients
369
376
370
377
Returns
371
378
-------
372
379
`pytensor.function`
373
380
"""
374
- if fn_kwargs is None :
375
- fn_kwargs = {}
381
+ if fn_kwargs is not None :
382
+ warnings .warn (
383
+ "`fn_kwargs` is deprecated and will be removed in future versions. Use "
384
+ "`compile_kwargs` instead." ,
385
+ DeprecationWarning ,
386
+ )
387
+ compile_kwargs = fn_kwargs
388
+
389
+ if compile_kwargs is None :
390
+ compile_kwargs = {}
376
391
if score and not self .op .returns_loss :
377
392
raise NotImplementedError (f"{ self .op } does not have loss" )
378
393
updates = self .updates (
@@ -388,14 +403,14 @@ def step_function(
388
403
)
389
404
seed = self .approx .rng .randint (2 ** 30 , dtype = np .int64 )
390
405
if score :
391
- step_fn = compile ([], updates .loss , updates = updates , random_seed = seed , ** fn_kwargs )
406
+ step_fn = compile ([], updates .loss , updates = updates , random_seed = seed , ** compile_kwargs )
392
407
else :
393
- step_fn = compile ([], [], updates = updates , random_seed = seed , ** fn_kwargs )
408
+ step_fn = compile ([], [], updates = updates , random_seed = seed , ** compile_kwargs )
394
409
return step_fn
395
410
396
411
@pytensor .config .change_flags (compute_test_value = "off" )
397
412
def score_function (
398
- self , sc_n_mc = None , more_replacements = None , fn_kwargs = None
413
+ self , sc_n_mc = None , more_replacements = None , compile_kwargs = None , fn_kwargs = None
399
414
): # pragma: no cover
400
415
R"""Compile scoring function that operates which takes no inputs and returns Loss.
401
416
@@ -405,22 +420,34 @@ def score_function(
405
420
number of scoring MC samples
406
421
more_replacements:
407
422
Apply custom replacements before compiling a function
423
+ compile_kwargs: `dict`
424
+ arbitrary kwargs passed to `pytensor.function`
408
425
fn_kwargs: `dict`
409
426
arbitrary kwargs passed to `pytensor.function`
410
427
428
+ .. warning:: `fn_kwargs` is deprecated and will be removed in future versions
429
+
411
430
Returns
412
431
-------
413
432
pytensor.function
414
433
"""
415
- if fn_kwargs is None :
416
- fn_kwargs = {}
434
+ if fn_kwargs is not None :
435
+ warnings .warn (
436
+ "`fn_kwargs` is deprecated and will be removed in future versions. Use "
437
+ "`compile_kwargs` instead" ,
438
+ DeprecationWarning ,
439
+ )
440
+ compile_kwargs = fn_kwargs
441
+
442
+ if compile_kwargs is None :
443
+ compile_kwargs = {}
417
444
if not self .op .returns_loss :
418
445
raise NotImplementedError (f"{ self .op } does not have loss" )
419
446
if more_replacements is None :
420
447
more_replacements = {}
421
448
loss = self (sc_n_mc , more_replacements = more_replacements )
422
449
seed = self .approx .rng .randint (2 ** 30 , dtype = np .int64 )
423
- return compile ([], loss , random_seed = seed , ** fn_kwargs )
450
+ return compile ([], loss , random_seed = seed , ** compile_kwargs )
424
451
425
452
@pytensor .config .change_flags (compute_test_value = "off" )
426
453
def __call__ (self , nmc , ** kwargs ):
@@ -451,7 +478,7 @@ class Operator:
451
478
require_logq = True
452
479
objective_class = ObjectiveFunction
453
480
supports_aevb = property (lambda self : not self .approx .any_histograms )
454
- T = identity
481
+ T = Elemwise ( scalar_identity )
455
482
456
483
def __init__ (self , approx ):
457
484
self .approx = approx
0 commit comments