Skip to content

Commit a87d95e

Browse files
committed
Add infer_shape to MinibatchRandomVariable
1 parent 3a718f2 commit a87d95e

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

pymc/variational/minibatch_rv.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def make_node(self, rv, *total_size):
4040
out = rv.type()
4141
return Apply(self, [rv, *total_size], [out])
4242

43+
def infer_shape(self, fgraph, node, shapes):
44+
return [shapes[0]]
45+
4346
def perform(self, node, inputs, output_storage):
4447
output_storage[0][0] = inputs[0]
4548

0 commit comments

Comments
 (0)