diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 58b90fdc6c..214a7bdd3d 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -3,7 +3,6 @@ from typing import Literal import numpy as np -import scipy.stats as stats from numpy import broadcast_shapes as np_broadcast_shapes from numpy import einsum as np_einsum from numpy import sqrt as np_sqrt @@ -21,6 +20,11 @@ ) +# Scipy.stats is considerably slow to import +# We import scipy.stats lazily inside `ScipyRandomVariable` +stats = None + + try: broadcast_shapes = np.broadcast_shapes except AttributeError: @@ -57,6 +61,9 @@ def rng_fn_scipy(cls, rng, *args, **kwargs): @classmethod def rng_fn(cls, *args, **kwargs): + global stats + if stats is None: + import scipy.stats as stats size = args[-1] res = cls.rng_fn_scipy(*args, **kwargs)