Commit acf5175

ValueGradFunction inner function now accepts a raveled input
1 parent f979bd6 commit acf5175

10 files changed: +187 additions, -100 deletions

pymc/model/core.py

Lines changed: 52 additions & 28 deletions

@@ -61,6 +61,7 @@
     gradient,
     hessian,
     inputvars,
+    join_nonshared_inputs,
     rewrite_pregrad,
 )
 from pymc.util import (
@@ -172,6 +173,9 @@ def __init__(
         dtype=None,
         casting="no",
         compute_grads=True,
+        model=None,
+        initial_point=None,
+        ravel_inputs: bool | None = None,
         **kwargs,
     ):
         if extra_vars_and_values is None:
@@ -219,9 +223,7 @@ def __init__(
         givens = []
         self._extra_vars_shared = {}
         for var, value in extra_vars_and_values.items():
-            shared = pytensor.shared(
-                value, var.name + "_shared__", shape=[1 if s == 1 else None for s in value.shape]
-            )
+            shared = pytensor.shared(value, var.name + "_shared__", shape=value.shape)
             self._extra_vars_shared[var.name] = shared
             givens.append((var, shared))

@@ -231,13 +233,28 @@ def __init__(
             grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore")
             for grad_wrt, var in zip(grads, grad_vars):
                 grad_wrt.name = f"{var.name}_grad"
-            outputs = [cost, *grads]
+            grads = pt.join(0, *[pt.atleast_1d(grad.ravel()) for grad in grads])
+            outputs = [cost, grads]
         else:
             outputs = [cost]

-        inputs = grad_vars
+        if ravel_inputs:
+            if initial_point is None:
+                initial_point = modelcontext(model).initial_point()
+            outputs, raveled_grad_vars = join_nonshared_inputs(
+                point=initial_point, inputs=grad_vars, outputs=outputs, make_inputs_shared=False
+            )
+            inputs = [raveled_grad_vars]
+        else:
+            if ravel_inputs is None:
+                warnings.warn(
+                    "ValueGradFunction will become a function of raveled inputs.\n"
+                    "Specify `ravel_inputs` to suppress this warning. Note that setting `ravel_inputs=False` will be forbidden in a future release."
+                )
+            inputs = grad_vars

         self._pytensor_function = compile_pymc(inputs, outputs, givens=givens, **kwargs)
+        self._raveled_inputs = ravel_inputs

     def set_weights(self, values):
         if values.shape != (self._n_costs - 1,):
@@ -247,38 +264,29 @@ def set_weights(self, values):
     def set_extra_values(self, extra_vars):
         self._extra_are_set = True
         for var in self._extra_vars:
-            self._extra_vars_shared[var.name].set_value(extra_vars[var.name])
+            self._extra_vars_shared[var.name].set_value(extra_vars[var.name], borrow=True)

     def get_extra_values(self):
         if not self._extra_are_set:
             raise ValueError("Extra values are not set.")

         return {var.name: self._extra_vars_shared[var.name].get_value() for var in self._extra_vars}

-    def __call__(self, grad_vars, grad_out=None, extra_vars=None):
+    def __call__(self, grad_vars, *, extra_vars=None):
         if extra_vars is not None:
             self.set_extra_values(extra_vars)
-
-        if not self._extra_are_set:
+        elif not self._extra_are_set:
             raise ValueError("Extra values are not set.")

         if isinstance(grad_vars, RaveledVars):
-            grad_vars = list(DictToArrayBijection.rmap(grad_vars).values())
-
-        cost, *grads = self._pytensor_function(*grad_vars)
-
-        if grads:
-            grads_raveled = DictToArrayBijection.map(
-                {v.name: gv for v, gv in zip(self._grad_vars, grads)}
-            )
-
-            if grad_out is None:
-                return cost, grads_raveled.data
+            if self._raveled_inputs:
+                grad_vars = (grad_vars.data,)
             else:
-                np.copyto(grad_out, grads_raveled.data)
-                return cost
-        else:
-            return cost
+                grad_vars = DictToArrayBijection.rmap(grad_vars).values()
+        elif self._raveled_inputs and not isinstance(grad_vars, Sequence):
+            grad_vars = (grad_vars,)
+
+        return self._pytensor_function(*grad_vars)

     @property
     def profile(self):
@@ -521,7 +529,14 @@ def root(self):
     def isroot(self):
         return self.parent is None

-    def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
+    def logp_dlogp_function(
+        self,
+        grad_vars=None,
+        tempered=False,
+        initial_point=None,
+        ravel_inputs: bool | None = None,
+        **kwargs,
+    ):
         """Compile a PyTensor function that computes logp and gradient.

         Parameters
@@ -547,13 +562,22 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
             costs = [self.logp()]

         input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
-        ip = self.initial_point(0)
+        if initial_point is None:
+            initial_point = self.initial_point(0)
         extra_vars_and_values = {
-            var: ip[var.name]
+            var: initial_point[var.name]
             for var in self.value_vars
             if var in input_vars and var not in grad_vars
         }
-        return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)
+        return ValueGradFunction(
+            costs,
+            grad_vars,
+            extra_vars_and_values,
+            model=self,
+            initial_point=initial_point,
+            ravel_inputs=ravel_inputs,
+            **kwargs,
+        )

     def compile_logp(
         self,
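For context, the net effect of these changes is that the compiled ValueGradFunction can be driven with a single raveled vector instead of one array per value variable. A minimal usage sketch (the toy model below is illustrative and not part of this commit; it assumes a model with no extra, non-gradient value variables):

import pymc as pm
from pymc.blocking import DictToArrayBijection

with pm.Model() as model:
    x = pm.Normal("x", 0.0, 1.0, shape=3)

# Compile logp and dlogp as a function of one concatenated 1D input
func = model.logp_dlogp_function(ravel_inputs=True)
func.set_extra_values({})  # this toy model has no extra value variables

# __call__ unwraps a RaveledVars (or accepts a plain 1D array) and returns (logp, dlogp)
raveled_point = DictToArrayBijection.map(model.initial_point(0))
logp, dlogp = func(raveled_point)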

pymc/sampling/mcmc.py

Lines changed: 2 additions & 0 deletions

@@ -1441,6 +1441,8 @@ def init_nuts(
         pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
     ]

+    logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True)
+    logp_dlogp_func.trust_input = True
     initial_points = _init_jitter(
         model,
         initvals,

pymc/step_methods/arraystep.py

Lines changed: 9 additions & 9 deletions

@@ -185,17 +185,17 @@ def __init__(
         model = modelcontext(model)

         if logp_dlogp_func is None:
-            func = model.logp_dlogp_function(vars, dtype=dtype, **pytensor_kwargs)
-        else:
-            func = logp_dlogp_func
-
-        self._logp_dlogp_func = func
+            logp_dlogp_func = model.logp_dlogp_function(
+                vars,
+                dtype=dtype,
+                ravel_inputs=True,
+                **pytensor_kwargs,
+            )
+            logp_dlogp_func.trust_input = True

-        super().__init__(vars, func._extra_vars_shared, blocked, rng=rng)
+        self._logp_dlogp_func = logp_dlogp_func

-    def step(self, point) -> tuple[PointType, StatsType]:
-        self._logp_dlogp_func._extra_are_set = True
-        return super().step(point)
+        super().__init__(vars, logp_dlogp_func._extra_vars_shared, blocked, rng=rng)


 def metrop_select(

pymc/step_methods/hmc/base_hmc.py

Lines changed: 2 additions & 4 deletions

@@ -194,8 +194,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         process_start = time.process_time()

         p0 = self.potential.random()
-        p0 = RaveledVars(p0, q0.point_map_info)
-
         start = self.integrator.compute_state(q0, p0)

         warning: SamplerWarning | None = None
@@ -226,13 +224,13 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         if self._step_rand is not None:
             step_size = self._step_rand(step_size, rng=self.rng)

-        hmc_step = self._hamiltonian_step(start, p0.data, step_size)
+        hmc_step = self._hamiltonian_step(start, p0, step_size)

         perf_end = time.perf_counter()
         process_end = time.process_time()

         self.step_adapt.update(hmc_step.accept_stat, adapt_step)
-        self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
+        self.potential.update(hmc_step.end.q.data, hmc_step.end.q_grad, self.tune)
         if hmc_step.divergence_info:
             info = hmc_step.divergence_info
             point = None

pymc/step_methods/hmc/integration.py

Lines changed: 30 additions & 21 deletions

@@ -18,13 +18,13 @@

 from scipy import linalg

-from pymc.blocking import RaveledVars
+from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.step_methods.hmc.quadpotential import QuadPotential


 class State(NamedTuple):
     q: RaveledVars
-    p: RaveledVars
+    p: np.ndarray
     v: np.ndarray
     q_grad: np.ndarray
     energy: float
@@ -40,23 +40,35 @@ class CpuLeapfrogIntegrator:
     def __init__(self, potential: QuadPotential, logp_dlogp_func):
         """Leapfrog integrator using CPU."""
         self._potential = potential
-        self._logp_dlogp_func = logp_dlogp_func
-        self._dtype = self._logp_dlogp_func.dtype
+        # Sidestep logp_dlogp_function.__call__
+        pytensor_function = logp_dlogp_func._pytensor_function
+        # Create some wrappers for backwards compatibility during transition
+        # When raveled_inputs=False is forbidden, func = pytensor_function
+        if logp_dlogp_func._raveled_inputs:
+
+            def func(q, _):
+                return pytensor_function(q)
+
+        else:
+
+            def func(q, point_map_info):
+                unraveled_q = DictToArrayBijection.rmap(RaveledVars(q, point_map_info)).values()
+                return pytensor_function(*unraveled_q)
+
+        self._logp_dlogp_func = func
+        self._dtype = logp_dlogp_func.dtype
         if self._potential.dtype != self._dtype:
             raise ValueError(
                 f"dtypes of potential ({self._potential.dtype}) and logp function ({self._dtype})"
                 "don't match."
             )

-    def compute_state(self, q: RaveledVars, p: RaveledVars):
+    def compute_state(self, q: RaveledVars, p: np.ndarray):
         """Compute Hamiltonian functions using a position and momentum."""
-        if q.data.dtype != self._dtype or p.data.dtype != self._dtype:
-            raise ValueError(f"Invalid dtype. Must be {self._dtype}")
-
-        logp, dlogp = self._logp_dlogp_func(q)
+        logp, dlogp = self._logp_dlogp_func(q.data, q.point_map_info)

-        v = self._potential.velocity(p.data, out=None)
-        kinetic = self._potential.energy(p.data, velocity=v)
+        v = self._potential.velocity(p, out=None)
+        kinetic = self._potential.energy(p, velocity=v)
         energy = kinetic - logp
         return State(q, p, v, dlogp, energy, logp, 0)

@@ -96,10 +108,10 @@ def _step(self, epsilon, state):
         axpy = linalg.blas.get_blas_funcs("axpy", dtype=self._dtype)
         pot = self._potential

-        q_new = state.q.data.copy()
-        p_new = state.p.data.copy()
+        q = state.q
+        q_new = q.data.copy()
+        p_new = state.p.copy()
         v_new = np.empty_like(q_new)
-        q_new_grad = np.empty_like(q_new)

         dt = 0.5 * epsilon

@@ -112,19 +124,16 @@ def _step(self, epsilon, state):
         # q_new = q + epsilon * v_new
         axpy(v_new, q_new, a=epsilon)

-        p_new = RaveledVars(p_new, state.p.point_map_info)
-        q_new = RaveledVars(q_new, state.q.point_map_info)
-
-        logp = self._logp_dlogp_func(q_new, grad_out=q_new_grad)
+        logp, q_new_grad = self._logp_dlogp_func(q_new, q.point_map_info)

         # p_new = p_new + dt * q_new_grad
-        axpy(q_new_grad, p_new.data, a=dt)
+        axpy(q_new_grad, p_new, a=dt)

-        kinetic = pot.velocity_energy(p_new.data, v_new)
+        kinetic = pot.velocity_energy(p_new, v_new)
         energy = kinetic - logp

         return State(
-            q_new,
+            RaveledVars(q_new, state.q.point_map_info),
             p_new,
             v_new,
             q_new_grad,
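In other words, the integrator now keeps the position as a RaveledVars but carries the momentum as a plain array, and it reaches the compiled function through a small wrapper instead of ValueGradFunction.__call__. A hedged sketch of the new calling convention (the toy model and the diagonal potential below are illustrative assumptions, not part of this commit):

import numpy as np
import pymc as pm
from pymc.blocking import DictToArrayBijection
from pymc.step_methods.hmc.integration import CpuLeapfrogIntegrator
from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt

with pm.Model() as model:
    x = pm.Normal("x", 0.0, 1.0, shape=3)

logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True)

n = 3
potential = QuadPotentialDiagAdapt(n, np.zeros(n))  # illustrative mass-matrix choice
integrator = CpuLeapfrogIntegrator(potential, logp_dlogp_func)

q0 = DictToArrayBijection.map(model.initial_point(0))  # position keeps the RaveledVars wrapper
p0 = potential.random()                                # momentum is now a plain np.ndarray
state = integrator.compute_state(q0, p0)               # State.p stores that ndarray directly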

pymc/step_methods/hmc/nuts.py

Lines changed: 6 additions & 6 deletions

@@ -279,7 +279,7 @@ def __init__(
         self.log_accept_sum = -np.inf
         self.mean_tree_accept = 0.0
         self.n_proposals = 0
-        self.p_sum = start.p.data.copy()
+        self.p_sum = start.p.copy()
         self.max_energy_change = 0.0

     def extend(self, direction):
@@ -330,9 +330,9 @@ def extend(self, direction):
         left, right = self.left, self.right
         p_sum = self.p_sum
         turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
-        p_sum1 = leftmost_p_sum + rightmost_begin.p.data
+        p_sum1 = leftmost_p_sum + rightmost_begin.p
         turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0)
-        p_sum2 = leftmost_end.p.data + rightmost_p_sum
+        p_sum2 = leftmost_end.p + rightmost_p_sum
         turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0)
         turning = turning | turning1 | turning2

@@ -372,7 +372,7 @@ def _single_step(self, left: State, epsilon: float):
                 right.model_logp,
                 right.index_in_trajectory,
             )
-            tree = Subtree(right, right, right.p.data, proposal, log_size)
+            tree = Subtree(right, right, right.p, proposal, log_size)
             return tree, None, False
         else:
             error_msg = f"Energy change in leapfrog step is too large: {energy_change}."
@@ -400,9 +400,9 @@ def _build_subtree(self, left, depth, epsilon):
         turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
         # Additional U turn check only when depth > 1 to avoid redundant work.
         if depth - 1 > 0:
-            p_sum1 = tree1.p_sum + tree2.left.p.data
+            p_sum1 = tree1.p_sum + tree2.left.p
             turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0)
-            p_sum2 = tree1.right.p.data + tree2.p_sum
+            p_sum2 = tree1.right.p + tree2.p_sum
             turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0)
             turning = turning | turning1 | turning2

tests/distributions/test_multivariate.py

Lines changed: 1 addition & 1 deletion

@@ -2395,7 +2395,7 @@ def test_mvnormal_no_cholesky_in_model_logp():
     d2logp = m.compile_d2logp()
     assert not contains_cholesky_op(d2logp.f.maker.fgraph)

-    logp_dlogp = m.logp_dlogp_function()
+    logp_dlogp = m.logp_dlogp_function(ravel_inputs=True)
     assert not contains_cholesky_op(logp_dlogp._pytensor_function.maker.fgraph)
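Relatedly, leaving ravel_inputs unspecified now goes through the warning branch added in ValueGradFunction.__init__. A hedged sketch of how that behaviour could be exercised in a test (hypothetical test, not part of this commit):

import pytest
import pymc as pm

def test_logp_dlogp_function_warns_without_ravel_inputs():
    with pm.Model() as m:
        pm.Normal("x")
    # warnings.warn defaults to UserWarning; the message mentions raveled inputs
    with pytest.warns(UserWarning, match="raveled inputs"):
        m.logp_dlogp_function()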
