Skip to content

Commit 73d4267

Browse files
Update infer_shape signatures
1 parent 6c5f09e commit 73d4267

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

pymc3/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,7 @@ def perform(self, node, inputs, outputs):
724724
pm._log.exception("Failed to check if %s positive definite", x)
725725
raise
726726

727-
def infer_shape(self, node, shapes):
727+
def infer_shape(self, fgraph, node, shapes):
728728
return [[]]
729729

730730
def grad(self, inp, grads):

pymc3/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def grad(self, inputs, gout):
363363
idx = tt.arange(gz.shape[-1])
364364
return [gz[..., idx, idx]]
365365

366-
def infer_shape(self, nodes, shapes):
366+
def infer_shape(self, fgraph, nodes, shapes):
367367
return [(shapes[0][0],) + (shapes[0][1],) * 2]
368368

369369

@@ -422,7 +422,7 @@ def grad(self, inputs, gout):
422422
]
423423
return [gout[0][slc] for slc in slices]
424424

425-
def infer_shape(self, nodes, shapes):
425+
def infer_shape(self, fgraph, nodes, shapes):
426426
first, second = zip(*shapes)
427427
return [(tt.add(*first), tt.add(*second))]
428428

pymc3/ode/ode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def perform(self, node, inputs_storage, output_storage):
213213
# simulate states and sensitivities in one forward pass
214214
output_storage[0][0], output_storage[1][0] = self._simulate(y0, theta)
215215

216-
def infer_shape(self, node, input_shapes):
216+
def infer_shape(self, fgraph, node, input_shapes):
217217
s_y0, s_theta = input_shapes
218218
output_shapes = [(self.n_times, self.n_states), (self.n_times, self.n_states, self.n_p)]
219219
return output_shapes

0 commit comments

Comments
 (0)