Skip to content

Commit b589ce8

Browse files
committed
Avoid input copy in Ndarray fn
Also initialize empty trace and set `trust_input=True`
1 parent e5eacb8 commit b589ce8

File tree

5 files changed

+25
-6
lines changed

5 files changed

+25
-6
lines changed

pymc/backends/base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
)
3131

3232
import numpy as np
33+
import pytensor
3334

3435
from pymc.backends.report import SamplerReport
3536
from pymc.model import modelcontext
37+
from pymc.pytensorf import compile_pymc
3638
from pymc.util import get_var_name
3739

3840
logger = logging.getLogger(__name__)
@@ -168,7 +170,13 @@ def __init__(
168170
raise Exception(f"Can't trace unnamed variables: {unnamed_vars}")
169171

170172
if fn is None:
171-
fn = model.compile_fn(vars, inputs=model.value_vars, on_unused_input="ignore")
173+
# borrow=True avoids deepcopy when inputs=output which is the case for untransformed value variables
174+
fn = compile_pymc(
175+
inputs=[pytensor.In(v, borrow=True) for v in model.value_vars],
176+
outputs=[pytensor.Out(v, borrow=True) for v in vars],
177+
on_unused_input="ignore",
178+
)
179+
fn.trust_input = True
172180

173181
# Get variable shapes. Most backends will need this
174182
# information.

pymc/backends/ndarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
7676
else: # Otherwise, make array of zeros for each variable.
7777
self.draws = draws
7878
for varname, shape in self.var_shapes.items():
79-
self.samples[varname] = np.zeros((draws, *shape), dtype=self.var_dtypes[varname])
79+
self.samples[varname] = np.empty((draws, *shape), dtype=self.var_dtypes[varname])
8080

8181
if sampler_vars is None:
8282
return
@@ -105,7 +105,7 @@ def record(self, point, sampler_stats=None) -> None:
105105
point: dict
106106
Values mapped to variable names
107107
"""
108-
for varname, value in zip(self.varnames, self.fn(point)):
108+
for varname, value in zip(self.varnames, self.fn(*point.values())):
109109
self.samples[varname][self.draw_idx] = value
110110

111111
if self._stats is not None and sampler_stats is None:

pymc/pytensorf.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,12 @@ def compile_pymc(
10241024
"""
10251025
# Create an update mapping of RandomVariable's RNG so that it is automatically
10261026
# updated after every function call
1027-
rng_updates = collect_default_updates(inputs=inputs, outputs=outputs)
1027+
rng_updates = collect_default_updates(
1028+
inputs=[inp.variable if isinstance(inp, pytensor.In) else inp for inp in inputs],
1029+
outputs=[
1030+
out.variable if isinstance(out, pytensor.Out) else out for out in makeiter(outputs)
1031+
],
1032+
)
10281033

10291034
# We always reseed random variables as this provides RNGs with no chances of collision
10301035
if rng_updates:

pymc/variational/opvi.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1554,7 +1554,10 @@ def sample(
15541554
if random_seed is not None:
15551555
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
15561556
samples: dict = self.sample_dict_fn(draws, random_seed=random_seed)
1557-
points = ({name: records[i] for name, records in samples.items()} for i in range(draws))
1557+
points = (
1558+
{name: np.asarray(records[i]) for name, records in samples.items()}
1559+
for i in range(draws)
1560+
)
15581561

15591562
trace = NDArray(
15601563
model=self.model,

tests/backends/fixtures.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,10 @@ class SamplingTestCase(ModelBackendSetupTestCase):
238238
"""
239239

240240
def record_point(self, val):
241-
point = {varname: np.tile(val, value.shape) for varname, value in self.test_point.items()}
241+
point = {
242+
varname: np.tile(val, value.shape).astype(value.dtype)
243+
for varname, value in self.test_point.items()
244+
}
242245
if self.sampler_vars is not None:
243246
stats = [{key: dtype(val) for key, dtype in vars.items()} for vars in self.sampler_vars]
244247
self.strace.record(point=point, sampler_stats=stats)

0 commit comments

Comments
 (0)