diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 10d5343511..58b90fdc6c 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -865,7 +865,7 @@ def __init__(self, *args, method: Literal["cholesky", "svd", "eigh"], **kwargs): ) self.method = method - def __call__(self, mean, cov, size=None, **kwargs): + def __call__(self, mean, cov, size=None, method=None, **kwargs): r""" "Draw samples from a multivariate normal distribution. Signature @@ -888,6 +888,12 @@ def __call__(self, mean, cov, size=None, **kwargs): is specified, a single `N`-dimensional sample is returned. """ + if method is not None and method != self.method: + # Recreate Op with the new method + props = self._props_dict() + props["method"] = method + new_op = type(self)(**props) + return new_op.__call__(mean, cov, size=size, method=method, **kwargs) return super().__call__(mean, cov, size=size, **kwargs) def rng_fn(self, rng, mean, cov, size): diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 4192a6c473..d7167b6a61 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -19,7 +19,6 @@ from pytensor.tensor import ones, stack from pytensor.tensor.random.basic import ( ChoiceWithoutReplacement, - MvNormalRV, PermutationRV, _gamma, bernoulli, @@ -707,7 +706,7 @@ def test_mvnormal_cov_decomposition_method(method, psd): [0, 0, 0], ] rng = shared(np.random.default_rng(675)) - draws = MvNormalRV(method=method)(mean, cov, rng=rng, size=(10_000,)) + draws = multivariate_normal(mean, cov, method=method, size=(10_000,), rng=rng) assert draws.owner.op.method == method # JAX doesn't raise errors at runtime