Skip to content

Commit e66a8fb

Browse files
ricardoV94twiecki
authored andcommitted
Fix HamiltonianMC returned datatype
1 parent 07a95fa commit e66a8fb

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pymc/step_methods/hmc/hmc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from pymc.step_methods.arraystep import Competence
1818
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
19-
from pymc.step_methods.hmc.integration import IntegrationError
19+
from pymc.step_methods.hmc.integration import IntegrationError, State
2020
from pymc.vartypes import discrete_types
2121

2222
__all__ = ["HamiltonianMC"]
@@ -161,6 +161,8 @@ def _hamiltonian_step(self, start, p0, step_size):
161161
"accepted": accepted,
162162
"model_logp": state.model_logp,
163163
}
164+
# Retrieve State q and p data from respective RaveledVars
165+
end = State(end.q.data, end.p.data, end.v, end.q_grad, end.energy, end.model_logp)
164166
return HMCStepData(end, accept_stat, div_info, stats)
165167

166168
@staticmethod

0 commit comments

Comments
 (0)