Skip to content

Commit 6cdfc30

Browse files
committed
Rename compile_pymc to compile
1 parent a714b24 commit 6cdfc30

File tree

21 files changed

+88
-81
lines changed

21 files changed

+88
-81
lines changed

docs/source/api/pytensorf.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ PyTensor utils
66
.. autosummary::
77
:toctree: generated/
88

9-
compile_pymc
9+
compile
1010
gradient
1111
hessian
1212
hessian_diag
@@ -19,6 +19,4 @@ PyTensor utils
1919
CallableTensor
2020
join_nonshared_inputs
2121
make_shared_replacements
22-
generator
23-
convert_generator_data
2422
convert_data

pymc/backends/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from pymc.backends.report import SamplerReport
3636
from pymc.model import modelcontext
37-
from pymc.pytensorf import compile_pymc
37+
from pymc.pytensorf import compile
3838
from pymc.util import get_var_name
3939

4040
logger = logging.getLogger(__name__)
@@ -171,7 +171,7 @@ def __init__(
171171

172172
if fn is None:
173173
# borrow=True avoids deepcopy when inputs=output which is the case for untransformed value variables
174-
fn = compile_pymc(
174+
fn = compile(
175175
inputs=[pytensor.In(v, borrow=True) for v in model.value_vars],
176176
outputs=[pytensor.Out(v, borrow=True) for v in vars],
177177
on_unused_input="ignore",

pymc/func_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,18 +169,18 @@ def find_constrained_prior(
169169
)
170170

171171
target = (pt.exp(logcdf_lower) - mass_below_lower) ** 2
172-
target_fn = pm.pytensorf.compile_pymc([dist_params], target, allow_input_downcast=True)
172+
target_fn = pm.pytensorf.compile([dist_params], target, allow_input_downcast=True)
173173

174174
constraint = pt.exp(logcdf_upper) - pt.exp(logcdf_lower)
175-
constraint_fn = pm.pytensorf.compile_pymc([dist_params], constraint, allow_input_downcast=True)
175+
constraint_fn = pm.pytensorf.compile([dist_params], constraint, allow_input_downcast=True)
176176

177177
jac: str | Callable
178178
constraint_jac: str | Callable
179179
try:
180180
pytensor_jac = pm.gradient(target, [dist_params])
181-
jac = pm.pytensorf.compile_pymc([dist_params], pytensor_jac, allow_input_downcast=True)
181+
jac = pm.pytensorf.compile([dist_params], pytensor_jac, allow_input_downcast=True)
182182
pytensor_constraint_jac = pm.gradient(constraint, [dist_params])
183-
constraint_jac = pm.pytensorf.compile_pymc(
183+
constraint_jac = pm.pytensorf.compile(
184184
[dist_params], pytensor_constraint_jac, allow_input_downcast=True
185185
)
186186
# when PyMC cannot compute the gradient

pymc/gp/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from scipy.cluster.vq import kmeans
2424

2525
from pymc.model.core import modelcontext
26-
from pymc.pytensorf import compile_pymc
26+
from pymc.pytensorf import compile
2727

2828
JITTER_DEFAULT = 1e-6
2929

@@ -55,7 +55,7 @@ def replace_with_values(vars_needed, replacements=None, model=None):
5555
if len(inputs) == 0:
5656
return tuple(v.eval() for v in vars_needed)
5757

58-
fn = compile_pymc(
58+
fn = compile(
5959
inputs,
6060
vars_needed,
6161
allow_input_downcast=True,

pymc/initial_point.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from pymc.logprob.transforms import Transform
2828
from pymc.pytensorf import (
29-
compile_pymc,
29+
compile,
3030
find_rng_nodes,
3131
replace_rng_nodes,
3232
reseed_rngs,
@@ -157,7 +157,7 @@ def make_initial_point_fn(
157157
# Replace original rng shared variables so that we don't mess with them
158158
# when calling the final seeded function
159159
initial_values = replace_rng_nodes(initial_values)
160-
func = compile_pymc(inputs=[], outputs=initial_values, mode=pytensor.compile.mode.FAST_COMPILE)
160+
func = compile(inputs=[], outputs=initial_values, mode=pytensor.compile.mode.FAST_COMPILE)
161161

162162
varnames = []
163163
for var in model.free_RVs:

pymc/model/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from pymc.pytensorf import (
5757
PointFunc,
5858
SeedSequenceSeed,
59-
compile_pymc,
59+
compile,
6060
convert_observed_data,
6161
gradient,
6262
hessian,
@@ -253,7 +253,7 @@ def __init__(
253253
)
254254
inputs = grad_vars
255255

256-
self._pytensor_function = compile_pymc(inputs, outputs, givens=givens, **kwargs)
256+
self._pytensor_function = compile(inputs, outputs, givens=givens, **kwargs)
257257
self._raveled_inputs = ravel_inputs
258258

259259
def set_weights(self, values):
@@ -1637,7 +1637,7 @@ def compile_fn(
16371637
inputs = inputvars(outs)
16381638

16391639
with self:
1640-
fn = compile_pymc(
1640+
fn = compile(
16411641
inputs,
16421642
outs,
16431643
allow_input_downcast=True,

pymc/pytensorf.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060

6161
__all__ = [
6262
"CallableTensor",
63+
"compile",
6364
"compile_pymc",
6465
"cont_inputs",
6566
"convert_data",
@@ -981,7 +982,7 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
981982
return rng_updates
982983

983984

984-
def compile_pymc(
985+
def compile(
985986
inputs,
986987
outputs,
987988
random_seed: SeedSequenceSeed = None,
@@ -990,7 +991,7 @@ def compile_pymc(
990991
) -> Function:
991992
"""Use ``pytensor.function`` with specialized pymc rewrites always enabled.
992993
993-
This function also ensures shared RandomState/Generator used by RandomVariables
994+
This function also ensures shared Generator used by RandomVariables
994995
in the graph are updated across calls, to ensure independent draws.
995996
996997
Parameters
@@ -1061,6 +1062,14 @@ def compile_pymc(
10611062
return pytensor_function
10621063

10631064

1065+
def compile_pymc(*args, **kwargs):
1066+
warnings.warn(
1067+
"compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC",
1068+
FutureWarning,
1069+
)
1070+
return compile(*args, **kwargs)
1071+
1072+
10641073
def constant_fold(
10651074
xs: Sequence[TensorVariable], raise_not_constant: bool = True
10661075
) -> tuple[np.ndarray | Variable, ...]:

pymc/sampling/forward.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from pymc.backends.base import MultiTrace
5252
from pymc.blocking import PointType
5353
from pymc.model import Model, modelcontext
54-
from pymc.pytensorf import compile_pymc
54+
from pymc.pytensorf import compile
5555
from pymc.util import (
5656
CustomProgress,
5757
RandomState,
@@ -273,7 +273,7 @@ def expand(node):
273273
]
274274

275275
return (
276-
compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs),
276+
compile(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs),
277277
set(basic_rvs) & (volatile_nodes - set(givens_dict)), # Basic RVs that will be resampled
278278
)
279279

@@ -329,7 +329,7 @@ def draw(
329329
if random_seed is not None:
330330
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
331331

332-
draw_fn = compile_pymc(inputs=[], outputs=vars, random_seed=random_seed, **kwargs)
332+
draw_fn = compile(inputs=[], outputs=vars, random_seed=random_seed, **kwargs)
333333

334334
if draws == 1:
335335
return draw_fn()

pymc/smc/kernels.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from pymc.initial_point import make_initial_point_expression
3131
from pymc.model import Point, modelcontext
3232
from pymc.pytensorf import (
33-
compile_pymc,
33+
compile,
3434
floatX,
3535
join_nonshared_inputs,
3636
make_shared_replacements,
@@ -636,6 +636,6 @@ def _logp_forw(point, out_vars, in_vars, shared):
636636
out_list, inarray0 = join_nonshared_inputs(
637637
point=point, outputs=out_vars, inputs=in_vars, shared_inputs=shared
638638
)
639-
f = compile_pymc([inarray0], out_list[0])
639+
f = compile([inarray0], out_list[0])
640640
f.trust_input = True
641641
return f

pymc/step_methods/metropolis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from pymc.initial_point import PointType
3232
from pymc.pytensorf import (
3333
CallableTensor,
34-
compile_pymc,
34+
compile,
3535
floatX,
3636
join_nonshared_inputs,
3737
replace_rng_nodes,
@@ -1241,6 +1241,6 @@ def delta_logp(
12411241

12421242
if compile_kwargs is None:
12431243
compile_kwargs = {}
1244-
f = compile_pymc([inarray1, inarray0], logp1 - logp0, **compile_kwargs)
1244+
f = compile([inarray1, inarray0], logp1 - logp0, **compile_kwargs)
12451245
f.trust_input = True
12461246
return f

pymc/step_methods/slicer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pymc.blocking import RaveledVars, StatsType
2121
from pymc.initial_point import PointType
2222
from pymc.model import modelcontext
23-
from pymc.pytensorf import compile_pymc, join_nonshared_inputs, make_shared_replacements
23+
from pymc.pytensorf import compile, join_nonshared_inputs, make_shared_replacements
2424
from pymc.step_methods.arraystep import ArrayStepShared
2525
from pymc.step_methods.compound import Competence, StepMethodState
2626
from pymc.step_methods.state import dataclass_state
@@ -109,7 +109,7 @@ def __init__(
109109
)
110110
if compile_kwargs is None:
111111
compile_kwargs = {}
112-
self.logp = compile_pymc([raveled_inp], logp, **compile_kwargs)
112+
self.logp = compile([raveled_inp], logp, **compile_kwargs)
113113
self.logp.trust_input = True
114114

115115
super().__init__(vars, shared, blocked=blocked, rng=rng)

pymc/testing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
local_check_parameter_to_ninf_switch,
4444
rvs_in_graph,
4545
)
46-
from pymc.pytensorf import compile_pymc, floatX, inputvars
46+
from pymc.pytensorf import compile, floatX, inputvars
4747

4848
# This mode can be used for tests where model compilations takes the bulk of the runtime
4949
# AND where we don't care about posterior numerical or sampling stability (e.g., when
@@ -645,7 +645,7 @@ def check_selfconsistency_discrete_logcdf(
645645
dist_logp_fn = pytensor.function(list(inputvars(dist_logp)), dist_logp)
646646

647647
dist_logcdf = logcdf(dist, value)
648-
dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf)
648+
dist_logcdf_fn = compile(list(inputvars(dist_logcdf)), dist_logcdf)
649649

650650
domains = paramdomains.copy()
651651
domains["value"] = domain
@@ -721,7 +721,7 @@ def continuous_random_tester(
721721

722722
model, param_vars = build_model(dist, valuedomain, paramdomains, extra_args)
723723
model_dist = change_dist_size(model.named_vars["value"], size, expand=True)
724-
pymc_rand = compile_pymc([], model_dist)
724+
pymc_rand = compile([], model_dist)
725725

726726
domains = paramdomains.copy()
727727
for point in product(domains, n_samples=100):
@@ -760,7 +760,7 @@ def discrete_random_tester(
760760

761761
model, param_vars = build_model(dist, valuedomain, paramdomains)
762762
model_dist = change_dist_size(model.named_vars["value"], size, expand=True)
763-
pymc_rand = compile_pymc([], model_dist)
763+
pymc_rand = compile([], model_dist)
764764

765765
domains = paramdomains.copy()
766766
for point in product(domains, n_samples=100):

pymc/variational/opvi.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
from pymc.model import modelcontext
7373
from pymc.pytensorf import (
7474
SeedSequenceSeed,
75-
compile_pymc,
75+
compile,
7676
find_rng_nodes,
7777
identity,
7878
reseed_rngs,
@@ -388,9 +388,9 @@ def step_function(
388388
)
389389
seed = self.approx.rng.randint(2**30, dtype=np.int64)
390390
if score:
391-
step_fn = compile_pymc([], updates.loss, updates=updates, random_seed=seed, **fn_kwargs)
391+
step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **fn_kwargs)
392392
else:
393-
step_fn = compile_pymc([], [], updates=updates, random_seed=seed, **fn_kwargs)
393+
step_fn = compile([], [], updates=updates, random_seed=seed, **fn_kwargs)
394394
return step_fn
395395

396396
@pytensor.config.change_flags(compute_test_value="off")
@@ -420,7 +420,7 @@ def score_function(
420420
more_replacements = {}
421421
loss = self(sc_n_mc, more_replacements=more_replacements)
422422
seed = self.approx.rng.randint(2**30, dtype=np.int64)
423-
return compile_pymc([], loss, random_seed=seed, **fn_kwargs)
423+
return compile([], loss, random_seed=seed, **fn_kwargs)
424424

425425
@pytensor.config.change_flags(compute_test_value="off")
426426
def __call__(self, nmc, **kwargs):
@@ -1517,7 +1517,7 @@ def sample_dict_fn(self):
15171517
names = [self.model.rvs_to_values[v].name for v in self.model.free_RVs]
15181518
sampled = [self.rslice(name) for name in names]
15191519
sampled = self.set_size_and_deterministic(sampled, s, 0)
1520-
sample_fn = compile_pymc([s], sampled)
1520+
sample_fn = compile([s], sampled)
15211521
rng_nodes = find_rng_nodes(sampled)
15221522

15231523
def inner(draws=100, *, random_seed: SeedSequenceSeed = None):

tests/distributions/test_distribution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
)
4444
from pymc.distributions.shape_utils import change_dist_size
4545
from pymc.logprob.basic import conditional_logp, logp
46-
from pymc.pytensorf import compile_pymc
46+
from pymc.pytensorf import compile
4747
from pymc.testing import (
4848
BaseTestDistributionRandom,
4949
I,
@@ -169,7 +169,7 @@ def update(self, node):
169169
outputs=[dummy_next_rng, dummy_x],
170170
ndim_supp=0,
171171
)(rng)
172-
fn = compile_pymc(inputs=[], outputs=x, random_seed=431)
172+
fn = compile(inputs=[], outputs=x, random_seed=431)
173173
assert fn() != fn()
174174

175175
# Check that custom updates are respected, by using one that's broken
@@ -182,7 +182,7 @@ def update(self, node):
182182
ValueError,
183183
match="No update found for at least one RNG used in SymbolicRandomVariable Op SymbolicRVCustomUpdates",
184184
):
185-
compile_pymc(inputs=[], outputs=x, random_seed=431)
185+
compile(inputs=[], outputs=x, random_seed=431)
186186

187187
def test_recreate_with_different_rng_inputs(self):
188188
"""Test that we can recreate a SymbolicRandomVariable with new RNG inputs.

tests/distributions/test_multivariate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from pymc.logprob.basic import logp
4646
from pymc.logprob.utils import ParameterValueError
4747
from pymc.math import kronecker
48-
from pymc.pytensorf import compile_pymc, floatX
48+
from pymc.pytensorf import compile, floatX
4949
from pymc.sampling.forward import draw
5050
from pymc.testing import (
5151
BaseTestDistributionRandom,
@@ -168,7 +168,7 @@ def stickbreakingweights_logpdf():
168168
_alpha = pt.scalar()
169169
_k = pt.iscalar()
170170
_logp = logp(pm.StickBreakingWeights.dist(_alpha, _k), _value)
171-
core_fn = compile_pymc([_value, _alpha, _k], _logp)
171+
core_fn = compile([_value, _alpha, _k], _logp)
172172

173173
return np.vectorize(core_fn, signature="(n),(),()->()")
174174

tests/distributions/test_shape_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def test_size_from_dims_rng_update(self):
326326
with pm.Model(coords={"x_dim": range(2)}):
327327
x = pm.Normal("x", dims=("x_dim",))
328328

329-
fn = pm.pytensorf.compile_pymc([], x)
329+
fn = pm.pytensorf.compile([], x)
330330
# Check that both function outputs (rng and draws) come from the same Apply node
331331
assert fn.maker.fgraph.outputs[0].owner is fn.maker.fgraph.outputs[1].owner
332332

@@ -341,7 +341,7 @@ def test_size_from_observed_rng_update(self):
341341
with pm.Model():
342342
x = pm.Normal("x", observed=[0, 1])
343343

344-
fn = pm.pytensorf.compile_pymc([], x)
344+
fn = pm.pytensorf.compile([], x)
345345
# Check that both function outputs (rng and draws) come from the same Apply node
346346
assert fn.maker.fgraph.outputs[0].owner is fn.maker.fgraph.outputs[1].owner
347347

0 commit comments

Comments
 (0)