Skip to content

Commit 7420495

Browse files
authored
Closes #3051 - Allows numpy array input to logp
Allows `logp(self,value)` to take `value` input of type numpy array without errors
1 parent 9418001 commit 7420495

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pymc3/distributions/multivariate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from theano.tensor.slinalg import Cholesky
2828
import pymc3 as pm
2929

30-
from pymc3.theanof import floatX
30+
from pymc3.theanof import floatX, intX
3131
from . import transforms
3232
from pymc3.util import get_variable_name
3333
from .distribution import (Continuous, Discrete, draw_values, generate_samples,
@@ -327,7 +327,7 @@ def logp(self, value):
327327
TensorVariable
328328
"""
329329
quaddist, logdet, ok = self._quaddist(value)
330-
k = theano.shared(value).shape[-1].astype(theano.config.floatX)
330+
k = intX(value.shape[-1]).astype(theano.config.floatX)
331331
norm = - 0.5 * k * pm.floatX(np.log(2 * np.pi))
332332
return bound(norm - 0.5 * quaddist - logdet, ok)
333333

@@ -441,7 +441,7 @@ def logp(self, value):
441441
TensorVariable
442442
"""
443443
quaddist, logdet, ok = self._quaddist(value)
444-
k = theano.shared(value).shape[-1].astype(theano.config.floatX)
444+
k = intX(value.shape[-1]).astype(theano.config.floatX)
445445

446446
norm = (gammaln((self.nu + k) / 2.)
447447
- gammaln(self.nu / 2.)

0 commit comments

Comments
 (0)