Skip to content

Commit 815b258

Browse files
committed
Benchmark minimal random function call
1 parent 6e57a08 commit 815b258

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

tests/compile/function/test_types.py

+14
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from pytensor.printing import debugprint
2020
from pytensor.tensor.math import dot, tanh
2121
from pytensor.tensor.math import sum as pt_sum
22+
from pytensor.tensor.random import normal
23+
from pytensor.tensor.random.type import random_generator_type
2224
from pytensor.tensor.type import (
2325
dmatrix,
2426
dscalar,
@@ -1280,3 +1282,15 @@ def test_empty_givens_updates():
12801282
y = x * 2
12811283
function([In(x)], y, givens={})
12821284
function([In(x)], y, updates={})
1285+
1286+
1287+
@pytest.mark.parametrize("trust_input", [True, False])
1288+
def test_minimal_random_function_call_benchmark(trust_input, benchmark):
1289+
rng = random_generator_type()
1290+
x = normal(rng=rng, size=(100,))
1291+
1292+
f = function([In(rng, mutable=True)], x)
1293+
f.trust_input = trust_input
1294+
1295+
rng_val = np.random.default_rng()
1296+
benchmark(f, rng_val)

0 commit comments

Comments
 (0)