Skip to content

Commit 31370f3

Browse files
ferrinetwiecki
authored andcommitted
add evaluate over trace (#3244)
* add evaluate over trace * Good remark Co-Authored-By: ferrine <[email protected]>
1 parent 54a434a commit 31370f3

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

pymc3/variational/approximations.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,26 @@ def __init__(self, trace=None, size=None, **kwargs):
570570
raise opvi.LocalGroupError('Empirical approximation does not support local variables')
571571
super(Empirical, self).__init__(trace=trace, size=size, **kwargs)
572572

573+
def evaluate_over_trace(self, node):
574+
R"""
575+
This allows to statically evaluate any symbolic expression over the trace.
576+
577+
Parameters
578+
----------
579+
node : Theano Variables (or Theano expressions)
580+
581+
Returns
582+
-------
583+
evaluated node(s) over the posterior trace contained in the empirical approximation
584+
"""
585+
node = self.to_flat_input(node)
586+
587+
def sample(post):
588+
return theano.clone(node, {self.input: post})
589+
590+
nodes, _ = theano.scan(sample, self.histogram)
591+
return nodes
592+
573593

574594
class NormalizingFlow(SingleGroupApproximation):
575595
__doc__ = """**Single Group Normalizing Flow Approximation**

0 commit comments

Comments
 (0)