Skip to content

Commit f12bea6

Browse files
authored
Handle MvNormal method in Op call (#1252)
1 parent 7f03125 commit f12bea6

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

pytensor/tensor/random/basic.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,7 @@ def __init__(self, *args, method: Literal["cholesky", "svd", "eigh"], **kwargs):
865865
)
866866
self.method = method
867867

868-
def __call__(self, mean, cov, size=None, **kwargs):
868+
def __call__(self, mean, cov, size=None, method=None, **kwargs):
869869
r""" "Draw samples from a multivariate normal distribution.
870870
871871
Signature
@@ -888,6 +888,12 @@ def __call__(self, mean, cov, size=None, **kwargs):
888888
is specified, a single `N`-dimensional sample is returned.
889889
890890
"""
891+
if method is not None and method != self.method:
892+
# Recreate Op with the new method
893+
props = self._props_dict()
894+
props["method"] = method
895+
new_op = type(self)(**props)
896+
return new_op.__call__(mean, cov, size=size, method=method, **kwargs)
891897
return super().__call__(mean, cov, size=size, **kwargs)
892898

893899
def rng_fn(self, rng, mean, cov, size):

tests/tensor/random/test_basic.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from pytensor.tensor import ones, stack
2020
from pytensor.tensor.random.basic import (
2121
ChoiceWithoutReplacement,
22-
MvNormalRV,
2322
PermutationRV,
2423
_gamma,
2524
bernoulli,
@@ -707,7 +706,7 @@ def test_mvnormal_cov_decomposition_method(method, psd):
707706
[0, 0, 0],
708707
]
709708
rng = shared(np.random.default_rng(675))
710-
draws = MvNormalRV(method=method)(mean, cov, rng=rng, size=(10_000,))
709+
draws = multivariate_normal(mean, cov, method=method, size=(10_000,), rng=rng)
711710
assert draws.owner.op.method == method
712711

713712
# JAX doesn't raise errors at runtime

0 commit comments

Comments
 (0)