Skip to content

Commit b3da2a4

Browse files
authored
Lazy import of scipy.stats (#1268)
1 parent 00fea0e commit b3da2a4

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

pytensor/tensor/random/basic.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Literal
44

55
import numpy as np
6-
import scipy.stats as stats
76
from numpy import broadcast_shapes as np_broadcast_shapes
87
from numpy import einsum as np_einsum
98
from numpy import sqrt as np_sqrt
@@ -21,6 +20,11 @@
2120
)
2221

2322

23+
# Scipy.stats is considerably slow to import
24+
# We import scipy.stats lazily inside `ScipyRandomVariable`
25+
stats = None
26+
27+
2428
try:
2529
broadcast_shapes = np.broadcast_shapes
2630
except AttributeError:
@@ -57,6 +61,9 @@ def rng_fn_scipy(cls, rng, *args, **kwargs):
5761

5862
@classmethod
5963
def rng_fn(cls, *args, **kwargs):
64+
global stats
65+
if stats is None:
66+
import scipy.stats as stats
6067
size = args[-1]
6168
res = cls.rng_fn_scipy(*args, **kwargs)
6269

0 commit comments

Comments
 (0)