|
57 | 57 |
|
58 | 58 | import pymc as pm
|
59 | 59 |
|
60 |
| -from pymc.aesaraf import at_rng, identity, rvs_to_value_vars |
| 60 | +from pymc.aesaraf import at_rng, compile_pymc, identity, rvs_to_value_vars |
61 | 61 | from pymc.backends import NDArray
|
62 | 62 | from pymc.blocking import DictToArrayBijection
|
63 | 63 | from pymc.initial_point import make_initial_point_fn
|
@@ -363,9 +363,9 @@ def step_function(
|
363 | 363 | total_grad_norm_constraint=total_grad_norm_constraint,
|
364 | 364 | )
|
365 | 365 | if score:
|
366 |
| - step_fn = aesara.function([], updates.loss, updates=updates, **fn_kwargs) |
| 366 | + step_fn = compile_pymc([], updates.loss, updates=updates, **fn_kwargs) |
367 | 367 | else:
|
368 |
| - step_fn = aesara.function([], None, updates=updates, **fn_kwargs) |
| 368 | + step_fn = compile_pymc([], None, updates=updates, **fn_kwargs) |
369 | 369 | return step_fn
|
370 | 370 |
|
371 | 371 | @aesara.config.change_flags(compute_test_value="off")
|
@@ -394,7 +394,7 @@ def score_function(
|
394 | 394 | if more_replacements is None:
|
395 | 395 | more_replacements = {}
|
396 | 396 | loss = self(sc_n_mc, more_replacements=more_replacements)
|
397 |
| - return aesara.function([], loss, **fn_kwargs) |
| 397 | + return compile_pymc([], loss, **fn_kwargs) |
398 | 398 |
|
399 | 399 | @aesara.config.change_flags(compute_test_value="off")
|
400 | 400 | def __call__(self, nmc, **kwargs):
|
@@ -1637,7 +1637,7 @@ def sample_dict_fn(self):
|
1637 | 1637 | names = [self.model.rvs_to_values[v].name for v in self.model.free_RVs]
|
1638 | 1638 | sampled = [self.rslice(name) for name in names]
|
1639 | 1639 | sampled = self.set_size_and_deterministic(sampled, s, 0)
|
1640 |
| - sample_fn = aesara.function([s], sampled) |
| 1640 | + sample_fn = compile_pymc([s], sampled) |
1641 | 1641 |
|
1642 | 1642 | def inner(draws=100):
|
1643 | 1643 | _samples = sample_fn(draws)
|
|
0 commit comments