diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py
index 286e508bf2..8034185283 100644
--- a/pymc/distributions/continuous.py
+++ b/pymc/distributions/continuous.py
@@ -1626,7 +1626,7 @@ def support_point(rv, size, b, kappa, mu):
     def logp(value, b, kappa, mu):
         value = value - mu
         res = pt.log(b / (kappa + (kappa**-1))) + (
-            -value * b * pt.sgn(value) * (kappa ** pt.sgn(value))
+            -value * b * pt.sign(value) * (kappa ** pt.sign(value))
         )
 
         return check_parameters(
diff --git a/pymc/distributions/moments/means.py b/pymc/distributions/moments/means.py
index 450a90f8a6..f025733726 100644
--- a/pymc/distributions/moments/means.py
+++ b/pymc/distributions/moments/means.py
@@ -50,7 +50,7 @@
     UniformRV,
     VonMisesRV,
 )
-from pytensor.tensor.var import TensorVariable
+from pytensor.tensor.variable import TensorVariable
 
 from pymc.distributions.continuous import (
     AsymmetricLaplaceRV,
diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py
index efd8d1f778..dcfb2e3d5b 100644
--- a/pymc/distributions/shape_utils.py
+++ b/pymc/distributions/shape_utils.py
@@ -461,7 +461,7 @@ def implicit_size_from_params(
     for param, ndim in zip(params, ndims_params):
         batch_shape = list(param.shape[:-ndim] if ndim > 0 else param.shape)
         # Overwrite broadcastable dims
-        for i, broadcastable in enumerate(param.type.broadcastable):
+        for i, broadcastable in enumerate(param.type.broadcastable[: len(batch_shape)]):
             if broadcastable:
                 batch_shape[i] = 1
         batch_shapes.append(batch_shape)
diff --git a/tests/distributions/test_shape_utils.py b/tests/distributions/test_shape_utils.py
index 493d7cb8c6..58f75e1cb8 100644
--- a/tests/distributions/test_shape_utils.py
+++ b/tests/distributions/test_shape_utils.py
@@ -35,9 +35,11 @@
     convert_size,
     get_support_shape,
     get_support_shape_1d,
+    implicit_size_from_params,
     rv_size_is_none,
 )
 from pymc.model import Model
+from pymc.pytensorf import constant_fold
 
 test_shapes = [
     ((), (1,), (4,), (5, 4)),
@@ -630,3 +632,10 @@ def test_get_support_shape(
         assert (f() == expected_support_shape).all()
         with pytest.raises(AssertionError, match="support_shape does not match"):
             inferred_support_shape.eval()
+
+
+def test_implicit_size_from_params():
+    x = pt.tensor(shape=(5, 1))
+    y = pt.tensor(shape=(3, 3))
+    res = implicit_size_from_params(x, y, ndims_params=[1, 2])
+    assert constant_fold([res]) == (5,)
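
Note: the shape_utils change above slices the broadcastable pattern to the batch dimensions only, so a parameter's core (support) dimensions of length 1 are no longer applied to, or overflow, the batch shape. A minimal standalone sketch of the intended behaviour, mirroring the new test (assumes pymc with this patch and pytensor are installed):

import pytensor.tensor as pt

from pymc.distributions.shape_utils import implicit_size_from_params
from pymc.pytensorf import constant_fold

# x has one core dim (ndims_params=1), so only its leading dim of length 5 is a batch dim;
# the trailing length-1 core dim must not be treated as a broadcastable batch dim.
x = pt.tensor(shape=(5, 1))
# y is entirely core (ndims_params=2) and contributes no batch dims.
y = pt.tensor(shape=(3, 3))

size = implicit_size_from_params(x, y, ndims_params=[1, 2])
assert constant_fold([size]) == (5,)  # the implicit batch size is (5,)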