Skip to content

Commit e5eacb8

Browse files
committed
Don't recompile Ndarray function on trace slicing
1 parent 24fbbe4 commit e5eacb8

File tree

2 files changed

+38
-19
lines changed

2 files changed

+38
-19
lines changed

pymc/backends/base.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -147,32 +147,45 @@ class BaseTrace(IBaseTrace):
147147
use different test point that might be with changed variables shapes
148148
"""
149149

150-
def __init__(self, name, model=None, vars=None, test_point=None):
151-
self.name = name
152-
150+
def __init__(
151+
self,
152+
name=None,
153+
model=None,
154+
vars=None,
155+
test_point=None,
156+
*,
157+
fn=None,
158+
var_shapes=None,
159+
var_dtypes=None,
160+
):
153161
model = modelcontext(model)
154-
self.model = model
162+
155163
if vars is None:
156164
vars = model.unobserved_value_vars
157165

158166
unnamed_vars = {var for var in vars if var.name is None}
159167
if unnamed_vars:
160168
raise Exception(f"Can't trace unnamed variables: {unnamed_vars}")
161-
self.vars = vars
162-
self.varnames = [var.name for var in vars]
163-
self.fn = model.compile_fn(vars, inputs=model.value_vars, on_unused_input="ignore")
169+
170+
if fn is None:
171+
fn = model.compile_fn(vars, inputs=model.value_vars, on_unused_input="ignore")
164172

165173
# Get variable shapes. Most backends will need this
166174
# information.
167-
if test_point is None:
168-
test_point = model.initial_point()
169-
else:
170-
test_point_ = model.initial_point().copy()
171-
test_point_.update(test_point)
172-
test_point = test_point_
173-
var_values = list(zip(self.varnames, self.fn(test_point)))
174-
self.var_shapes = {var: value.shape for var, value in var_values}
175-
self.var_dtypes = {var: value.dtype for var, value in var_values}
175+
if var_shapes is None or var_dtypes is None:
176+
if test_point is None:
177+
test_point = model.initial_point()
178+
var_values = tuple(zip(vars, fn(**test_point)))
179+
var_shapes = {var.name: value.shape for var, value in var_values}
180+
var_dtypes = {var.name: value.dtype for var, value in var_values}
181+
182+
self.name = name
183+
self.model = model
184+
self.fn = fn
185+
self.vars = vars
186+
self.varnames = [var.name for var in vars]
187+
self.var_shapes = var_shapes
188+
self.var_dtypes = var_dtypes
176189
self.chain = None
177190
self._is_base_setup = False
178191
self.sampler_vars = None

pymc/backends/ndarray.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ class NDArray(base.BaseTrace):
4040
`model.unobserved_RVs` is used.
4141
"""
4242

43-
def __init__(self, name=None, model=None, vars=None, test_point=None):
44-
super().__init__(name, model, vars, test_point)
43+
def __init__(self, name=None, model=None, vars=None, test_point=None, **kwargs):
44+
super().__init__(name, model, vars, test_point, **kwargs)
4545
self.draw_idx = 0
4646
self.draws = None
4747
self.samples = {}
@@ -166,7 +166,13 @@ def _slice(self, idx: slice):
166166
# Only the first `draw_idx` value are valid because of preallocation
167167
idx = slice(*idx.indices(len(self)))
168168

169-
sliced = NDArray(model=self.model, vars=self.vars)
169+
sliced = type(self)(
170+
model=self.model,
171+
vars=self.vars,
172+
fn=self.fn,
173+
var_shapes=self.var_shapes,
174+
var_dtypes=self.var_dtypes,
175+
)
170176
sliced.chain = self.chain
171177
sliced.samples = {varname: values[idx] for varname, values in self.samples.items()}
172178
sliced.sampler_vars = self.sampler_vars

0 commit comments

Comments
 (0)