Skip to content

Commit 2813811

Browse files
committed
Simplify logp_dlogp_function
1 parent 61ad1ca commit 2813811

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

pymc/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
614614
`alpha` can be changed using `ValueGradFunction.set_weights([alpha])`.
615615
"""
616616
if grad_vars is None:
617-
grad_vars = [self.rvs_to_values[v] for v in typefilter(self.free_RVs, continuous_types)]
617+
grad_vars = self.continuous_value_vars
618618
else:
619619
for i, var in enumerate(grad_vars):
620620
if var.dtype not in continuous_types:
@@ -626,10 +626,11 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
626626
costs = [self.logp()]
627627

628628
input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
629-
extra_vars = [self.rvs_to_values.get(var, var) for var in self.free_RVs]
630629
ip = self.initial_point(0)
631630
extra_vars_and_values = {
632-
var: ip[var.name] for var in extra_vars if var in input_vars and var not in grad_vars
631+
var: ip[var.name]
632+
for var in self.value_vars
633+
if var in input_vars and var not in grad_vars
633634
}
634635
return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)
635636

0 commit comments

Comments
 (0)