Commit 52f7fe1

Forward compile_kwargs to ADVI when init = "advi+..." (#7640)
1 parent 82716fb commit 52f7fe1

File tree

4 files changed: +56 additions, −42 deletions


pymc/pytensorf.py

Lines changed: 0 additions & 26 deletions
@@ -22,7 +22,6 @@
 import pytensor.tensor as pt
 import scipy.sparse as sps
 
-from pytensor import scalar
 from pytensor.compile import Function, Mode, get_mode
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import grad
@@ -415,31 +414,6 @@ def hessian_diag(f, vars=None, negate_output=True):
     return empty_gradient
 
 
-class IdentityOp(scalar.UnaryScalarOp):
-    @staticmethod
-    def st_impl(x):
-        return x
-
-    def impl(self, x):
-        return x
-
-    def grad(self, inp, grads):
-        return grads
-
-    def c_code(self, node, name, inp, out, sub):
-        return f"{out[0]} = {inp[0]};"
-
-    def __eq__(self, other):
-        return isinstance(self, type(other))
-
-    def __hash__(self):
-        return hash(type(self))
-
-
-scalar_identity = IdentityOp(scalar.upgrade_to_float, name="scalar_identity")
-identity = Elemwise(scalar_identity, name="identity")
-
-
 def make_shared_replacements(point, vars, model):
     """
     Make shared replacements for all *other* variables than the ones passed.
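The hand-rolled IdentityOp is dropped because pytensor already ships an equivalent scalar identity op; opvi.py below rebuilds the elementwise version from it. A minimal sketch of that replacement, using only the two imports this commit adds (the test values are illustrative):

import numpy as np

from pytensor.scalar.basic import identity as scalar_identity
from pytensor.tensor.elemwise import Elemwise

# Elementwise identity built from pytensor's own scalar op; this is what
# replaces the deleted IdentityOp / scalar_identity / identity trio.
identity = Elemwise(scalar_identity)

x = np.array([1.0, 2.0, 3.0])
assert np.array_equal(identity(x).eval(), x)  # values pass through unchanged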

pymc/sampling/mcmc.py

Lines changed: 4 additions & 0 deletions
@@ -1553,6 +1553,7 @@ def init_nuts(
             callbacks=cb,
             progressbar=progressbar,
             obj_optimizer=pm.adagrad_window,
+            compile_kwargs=compile_kwargs,
         )
         approx_sample = approx.sample(
             draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
@@ -1566,6 +1567,7 @@ def init_nuts(
         potential = quadpotential.QuadPotentialDiagAdapt(
             n, mean, cov, weight, rng=random_seed_list[0]
         )
+
     elif init == "advi":
         approx = pm.fit(
             random_seed=random_seed_list[0],
@@ -1575,6 +1577,7 @@ def init_nuts(
             callbacks=cb,
             progressbar=progressbar,
             obj_optimizer=pm.adagrad_window,
+            compile_kwargs=compile_kwargs,
         )
         approx_sample = approx.sample(
             draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
@@ -1592,6 +1595,7 @@ def init_nuts(
             callbacks=cb,
             progressbar=progressbar,
             obj_optimizer=pm.adagrad_window,
+            compile_kwargs=compile_kwargs,
        )
         approx_sample = approx.sample(
             draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
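Each of the three pm.fit call sites in init_nuts now forwards compile_kwargs, so ADVI-based initializations compile with the same options as the sampler itself. Assuming pm.sample passes its compile_kwargs argument through to init_nuts, as these call sites imply, the user-facing effect looks like this (the toy model and the mode flag are illustrative, not from this commit):

import pymc as pm

with pm.Model():
    x = pm.Normal("x", 0.0, 1.0)
    # compile_kwargs now also reaches the ADVI fit that seeds the
    # mass matrix, not only the main sampling functions.
    idata = pm.sample(
        init="advi+adapt_diag",
        chains=2,
        tune=500,
        draws=500,
        compile_kwargs={"mode": "FAST_COMPILE"},
    )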

pymc/variational/inference.py

Lines changed: 14 additions & 5 deletions
@@ -82,9 +82,18 @@ def _maybe_score(self, score):
 
     def run_profiling(self, n=1000, score=None, **kwargs):
         score = self._maybe_score(score)
-        fn_kwargs = kwargs.pop("fn_kwargs", {})
-        fn_kwargs["profile"] = True
-        step_func = self.objective.step_function(score=score, fn_kwargs=fn_kwargs, **kwargs)
+        if "fn_kwargs" in kwargs:
+            warnings.warn(
+                "fn_kwargs is deprecated, please use compile_kwargs instead", DeprecationWarning
+            )
+            compile_kwargs = kwargs.pop("fn_kwargs")
+        else:
+            compile_kwargs = kwargs.pop("compile_kwargs", {})
+
+        compile_kwargs["profile"] = True
+        step_func = self.objective.step_function(
+            score=score, compile_kwargs=compile_kwargs, **kwargs
+        )
         try:
             for _ in track(range(n)):
                 step_func()
@@ -134,7 +143,7 @@ def fit(
        Add custom updates to resulting updates
    total_grad_norm_constraint: `float`
        Bounds gradient norm, prevents exploding gradient problem
-    fn_kwargs: `dict`
+    compile_kwargs: `dict`
        Add kwargs to pytensor.function (e.g. `{'profile': True}`)
    more_replacements: `dict`
        Apply custom replacements before calculating gradients
@@ -729,7 +738,7 @@ def fit(
        Add custom updates to resulting updates
    total_grad_norm_constraint: `float`
        Bounds gradient norm, prevents exploding gradient problem
-    fn_kwargs: `dict`
+    compile_kwargs: `dict`
        Add kwargs to pytensor.function (e.g. `{'profile': True}`)
    more_replacements: `dict`
        Apply custom replacements before calculating gradients
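run_profiling therefore keeps accepting the old fn_kwargs spelling, but warns and prefers compile_kwargs. A usage sketch, assuming a simple ADVI fit (the model and the extra kwarg are illustrative):

import pymc as pm

with pm.Model():
    pm.Normal("x", 0.0, 1.0)
    advi = pm.ADVI()

    # New spelling: extra pytensor.function kwargs go in compile_kwargs;
    # run_profiling sets compile_kwargs["profile"] = True itself.
    advi.run_profiling(n=100, compile_kwargs={"on_unused_input": "ignore"})

    # Old spelling still runs, but emits a DeprecationWarning.
    advi.run_profiling(n=100, fn_kwargs={"on_unused_input": "ignore"})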

pymc/variational/opvi.py

Lines changed: 38 additions & 11 deletions
@@ -61,6 +61,8 @@
 
 from pytensor.graph.basic import Variable
 from pytensor.graph.replace import graph_replace
+from pytensor.scalar.basic import identity as scalar_identity
+from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.shape import unbroadcast
 
 import pymc as pm
@@ -74,7 +76,6 @@
     SeedSequenceSeed,
     compile,
     find_rng_nodes,
-    identity,
     reseed_rngs,
 )
 from pymc.util import (
@@ -332,6 +333,7 @@ def step_function(
         more_replacements=None,
         total_grad_norm_constraint=None,
         score=False,
+        compile_kwargs=None,
         fn_kwargs=None,
     ):
         R"""Step function that should be called on each optimization step.
@@ -362,17 +364,30 @@ def step_function(
             Bounds gradient norm, prevents exploding gradient problem
         score: `bool`
             calculate loss on each step? Defaults to False for speed
-        fn_kwargs: `dict`
+        compile_kwargs: `dict`
             Add kwargs to pytensor.function (e.g. `{'profile': True}`)
+        fn_kwargs: dict
+            arbitrary kwargs passed to `pytensor.function`
+
+            .. warning:: `fn_kwargs` is deprecated and will be removed in future versions
+
         more_replacements: `dict`
             Apply custom replacements before calculating gradients
 
         Returns
         -------
         `pytensor.function`
         """
-        if fn_kwargs is None:
-            fn_kwargs = {}
+        if fn_kwargs is not None:
+            warnings.warn(
+                "`fn_kwargs` is deprecated and will be removed in future versions. Use "
+                "`compile_kwargs` instead.",
+                DeprecationWarning,
+            )
+            compile_kwargs = fn_kwargs
+
+        if compile_kwargs is None:
+            compile_kwargs = {}
         if score and not self.op.returns_loss:
             raise NotImplementedError(f"{self.op} does not have loss")
         updates = self.updates(
@@ -388,14 +403,14 @@ def step_function(
         )
         seed = self.approx.rng.randint(2**30, dtype=np.int64)
         if score:
-            step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **fn_kwargs)
+            step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **compile_kwargs)
         else:
-            step_fn = compile([], [], updates=updates, random_seed=seed, **fn_kwargs)
+            step_fn = compile([], [], updates=updates, random_seed=seed, **compile_kwargs)
         return step_fn
 
     @pytensor.config.change_flags(compute_test_value="off")
     def score_function(
-        self, sc_n_mc=None, more_replacements=None, fn_kwargs=None
+        self, sc_n_mc=None, more_replacements=None, compile_kwargs=None, fn_kwargs=None
     ):  # pragma: no cover
         R"""Compile scoring function that operates which takes no inputs and returns Loss.
 
@@ -405,22 +420,34 @@ def score_function(
             number of scoring MC samples
         more_replacements:
             Apply custom replacements before compiling a function
+        compile_kwargs: `dict`
+            arbitrary kwargs passed to `pytensor.function`
         fn_kwargs: `dict`
             arbitrary kwargs passed to `pytensor.function`
 
+            .. warning:: `fn_kwargs` is deprecated and will be removed in future versions
+
         Returns
         -------
         pytensor.function
         """
-        if fn_kwargs is None:
-            fn_kwargs = {}
+        if fn_kwargs is not None:
+            warnings.warn(
+                "`fn_kwargs` is deprecated and will be removed in future versions. Use "
+                "`compile_kwargs` instead",
+                DeprecationWarning,
+            )
+            compile_kwargs = fn_kwargs
+
+        if compile_kwargs is None:
+            compile_kwargs = {}
         if not self.op.returns_loss:
             raise NotImplementedError(f"{self.op} does not have loss")
         if more_replacements is None:
             more_replacements = {}
         loss = self(sc_n_mc, more_replacements=more_replacements)
         seed = self.approx.rng.randint(2**30, dtype=np.int64)
-        return compile([], loss, random_seed=seed, **fn_kwargs)
+        return compile([], loss, random_seed=seed, **compile_kwargs)
 
     @pytensor.config.change_flags(compute_test_value="off")
     def __call__(self, nmc, **kwargs):
@@ -451,7 +478,7 @@ class Operator:
     require_logq = True
     objective_class = ObjectiveFunction
     supports_aevb = property(lambda self: not self.approx.any_histograms)
-    T = identity
+    T = Elemwise(scalar_identity)
 
     def __init__(self, approx):
         self.approx = approx
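At the bottom of the chain, step_function and score_function simply splat compile_kwargs into pymc's compile wrapper, so any pytensor.function keyword is accepted. A sketch of calling them directly, assuming the ADVI objective as a convenient ObjectiveFunction instance (model and kwargs illustrative):

import pymc as pm

with pm.Model():
    pm.Normal("x", 0.0, 1.0)
    advi = pm.ADVI()

    # compile_kwargs is forwarded verbatim to the pytensor function compiler.
    step = advi.objective.step_function(
        score=True, compile_kwargs={"profile": True}
    )
    loss = step()  # one optimization step; returns the loss because score=True

    # Passing fn_kwargs instead still works but raises a DeprecationWarning.
    score_fn = advi.objective.score_function(fn_kwargs={"profile": True})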
