Skip to content

Commit 1cb1418

Browse files
committed
Use compile_pymc instead of aesara.function
1 parent 32006cd commit 1cb1418

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pymc/variational/opvi.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757

5858
import pymc as pm
5959

60-
from pymc.aesaraf import at_rng, identity, rvs_to_value_vars
60+
from pymc.aesaraf import at_rng, compile_pymc, identity, rvs_to_value_vars
6161
from pymc.backends import NDArray
6262
from pymc.blocking import DictToArrayBijection
6363
from pymc.initial_point import make_initial_point_fn
@@ -363,9 +363,9 @@ def step_function(
363363
total_grad_norm_constraint=total_grad_norm_constraint,
364364
)
365365
if score:
366-
step_fn = aesara.function([], updates.loss, updates=updates, **fn_kwargs)
366+
step_fn = compile_pymc([], updates.loss, updates=updates, **fn_kwargs)
367367
else:
368-
step_fn = aesara.function([], None, updates=updates, **fn_kwargs)
368+
step_fn = compile_pymc([], None, updates=updates, **fn_kwargs)
369369
return step_fn
370370

371371
@aesara.config.change_flags(compute_test_value="off")
@@ -394,7 +394,7 @@ def score_function(
394394
if more_replacements is None:
395395
more_replacements = {}
396396
loss = self(sc_n_mc, more_replacements=more_replacements)
397-
return aesara.function([], loss, **fn_kwargs)
397+
return compile_pymc([], loss, **fn_kwargs)
398398

399399
@aesara.config.change_flags(compute_test_value="off")
400400
def __call__(self, nmc, **kwargs):
@@ -1637,7 +1637,7 @@ def sample_dict_fn(self):
16371637
names = [self.model.rvs_to_values[v].name for v in self.model.free_RVs]
16381638
sampled = [self.rslice(name) for name in names]
16391639
sampled = self.set_size_and_deterministic(sampled, s, 0)
1640-
sample_fn = aesara.function([s], sampled)
1640+
sample_fn = compile_pymc([s], sampled)
16411641

16421642
def inner(draws=100):
16431643
_samples = sample_fn(draws)

0 commit comments

Comments
 (0)