Skip to content

Commit afad405

Browse files
committed
fix the bug
1 parent b7200d7 commit afad405

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

pymc/logprob/transforms.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
4141

4242
import numpy as np
43+
import pytensor
4344
import pytensor.tensor as pt
4445

4546
from pytensor import scan
@@ -959,7 +960,8 @@ class SimplexTransform(RVTransform):
959960

960961
def forward(self, value, *inputs):
961962
log_value = pt.log(value)
962-
shift = pt.sum(log_value, -1, keepdims=True) / value.shape[-1]
963+
N = value.shape[-1].astype(pytensor.config.floatX)
964+
shift = pt.sum(log_value, -1, keepdims=True) / N
963965
return log_value[..., :-1] - shift
964966

965967
def backward(self, value, *inputs):
@@ -969,6 +971,7 @@ def backward(self, value, *inputs):
969971

970972
def log_jac_det(self, value, *inputs):
971973
N = value.shape[-1] + 1
974+
N = N.astype(pytensor.config.floatX)
972975
sum_value = pt.sum(value, -1, keepdims=True)
973976
value_sum_expanded = value + sum_value
974977
value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros(sum_value.shape)], -1)

tests/distributions/test_multivariate.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,13 +1716,13 @@ class TestDirichlet(BaseTestDistributionRandom):
17161716

17171717
@pytensor.config.change_flags(warn_float64="raise", floatX="float32")
17181718
def test_dirichlet_float32(self):
1719-
"""https://github.com/pymc-devs/pymc/issues/6779
1720-
"""
1719+
"""https://github.com/pymc-devs/pymc/issues/6779"""
17211720
with pm.Model() as model:
17221721
c = pm.floatX([1, 1, 1])
17231722
pm.Dirichlet("a", c)
17241723
model.point_logps()
17251724

1725+
17261726
class TestMultinomial(BaseTestDistributionRandom):
17271727
pymc_dist = pm.Multinomial
17281728
pymc_dist_params = {"n": 85, "p": np.array([0.28, 0.62, 0.10])}
@@ -2134,4 +2134,3 @@ def test_posdef_symmetric(matrix, result):
21342134
"""
21352135
data = np.array(matrix, dtype=pytensor.config.floatX)
21362136
assert posdef(data) == result
2137-

0 commit comments

Comments
 (0)