Skip to content

Commit 03cab87

Browse files
committed
Default to JAX test mode in random tests
1 parent c1ecbe0 commit 03cab87

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

tests/link/jax/test_random.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
2727

2828

29-
def compile_random_function(*args, mode="JAX", **kwargs):
29+
def compile_random_function(*args, mode=jax_mode, **kwargs):
3030
with pytest.warns(
3131
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
3232
):
@@ -41,7 +41,7 @@ def test_random_RandomStream():
4141
srng = RandomStream(seed=123)
4242
out = srng.normal() - srng.normal()
4343

44-
fn = compile_random_function([], out, mode=jax_mode)
44+
fn = compile_random_function([], out)
4545
jax_res_1 = fn()
4646
jax_res_2 = fn()
4747

@@ -54,7 +54,7 @@ def test_random_updates(rng_ctor):
5454
rng = shared(original_value, name="original_rng", borrow=False)
5555
next_rng, x = pt.random.normal(name="x", rng=rng).owner.outputs
5656

57-
f = compile_random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
57+
f = compile_random_function([], [x], updates={rng: next_rng})
5858
assert f() != f()
5959

6060
# Check that original rng variable content was not overwritten when calling jax_typify
@@ -482,7 +482,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
482482
)
483483
rng = shared(np.random.default_rng(29403))
484484
g = rv_op(*dist_params, size=(10000, *base_size), rng=rng)
485-
g_fn = compile_random_function(dist_params, g, mode=jax_mode)
485+
g_fn = compile_random_function(dist_params, g)
486486
samples = g_fn(*test_values)
487487

488488
bcast_dist_args = np.broadcast_arrays(*test_values)
@@ -518,7 +518,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
518518
param_that_implies_size = pt.matrix("param_that_implies_size", shape=(None, None))
519519

520520
rv = rv_fn(param_that_implies_size)
521-
draws = rv.eval({param_that_implies_size: np.zeros((2, 2))}, mode=jax_mode)
521+
draws = rv.eval({param_that_implies_size: np.zeros((2, 2))})
522522

523523
assert draws.shape == (2, 2)
524524
assert np.unique(draws).size == 4
@@ -528,7 +528,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
528528
def test_random_bernoulli(size):
529529
rng = shared(np.random.default_rng(123))
530530
g = pt.random.bernoulli(0.5, size=(1000, *size), rng=rng)
531-
g_fn = compile_random_function([], g, mode=jax_mode)
531+
g_fn = compile_random_function([], g)
532532
samples = g_fn()
533533
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
534534

@@ -539,7 +539,7 @@ def test_random_mvnormal():
539539
mu = np.ones(4)
540540
cov = np.eye(4)
541541
g = pt.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
542-
g_fn = compile_random_function([], g, mode=jax_mode)
542+
g_fn = compile_random_function([], g)
543543
samples = g_fn()
544544
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)
545545

@@ -559,7 +559,7 @@ def test_random_mvnormal():
559559
def test_random_dirichlet(parameter, size):
560560
rng = shared(np.random.default_rng(123))
561561
g = pt.random.dirichlet(parameter, size=(1000, *size), rng=rng)
562-
g_fn = compile_random_function([], g, mode=jax_mode)
562+
g_fn = compile_random_function([], g)
563563
samples = g_fn()
564564
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
565565

@@ -568,7 +568,7 @@ def test_random_choice():
568568
# `replace=True` and `p is None`
569569
rng = shared(np.random.default_rng(123))
570570
g = pt.random.choice(np.arange(4), size=10_000, rng=rng)
571-
g_fn = compile_random_function([], g, mode=jax_mode)
571+
g_fn = compile_random_function([], g)
572572
samples = g_fn()
573573
assert samples.shape == (10_000,)
574574
# Elements are picked at equal frequency
@@ -577,7 +577,7 @@ def test_random_choice():
577577
# `replace=True` and `p is not None`
578578
rng = shared(np.random.default_rng(123))
579579
g = pt.random.choice(4, p=np.array([0.0, 0.5, 0.0, 0.5]), size=(5, 2), rng=rng)
580-
g_fn = compile_random_function([], g, mode=jax_mode)
580+
g_fn = compile_random_function([], g)
581581
samples = g_fn()
582582
assert samples.shape == (5, 2)
583583
# Only odd numbers are picked
@@ -586,7 +586,7 @@ def test_random_choice():
586586
# `replace=False` and `p is None`
587587
rng = shared(np.random.default_rng(123))
588588
g = pt.random.choice(np.arange(100), replace=False, size=(2, 49), rng=rng)
589-
g_fn = compile_random_function([], g, mode=jax_mode)
589+
g_fn = compile_random_function([], g)
590590
samples = g_fn()
591591
assert samples.shape == (2, 49)
592592
# Elements are unique
@@ -601,7 +601,7 @@ def test_random_choice():
601601
rng=rng,
602602
replace=False,
603603
)
604-
g_fn = compile_random_function([], g, mode=jax_mode)
604+
g_fn = compile_random_function([], g)
605605
samples = g_fn()
606606
assert samples.shape == (3,)
607607
# Elements are unique
@@ -613,14 +613,14 @@ def test_random_choice():
613613
def test_random_categorical():
614614
rng = shared(np.random.default_rng(123))
615615
g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
616-
g_fn = compile_random_function([], g, mode=jax_mode)
616+
g_fn = compile_random_function([], g)
617617
samples = g_fn()
618618
assert samples.shape == (10000, 4)
619619
np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)
620620

621621
# Test zero probabilities
622622
g = pt.random.categorical([0, 0.5, 0, 0.5], size=(1000,), rng=rng)
623-
g_fn = compile_random_function([], g, mode=jax_mode)
623+
g_fn = compile_random_function([], g)
624624
samples = g_fn()
625625
assert samples.shape == (1000,)
626626
assert np.all(samples % 2 == 1)
@@ -630,7 +630,7 @@ def test_random_permutation():
630630
array = np.arange(4)
631631
rng = shared(np.random.default_rng(123))
632632
g = pt.random.permutation(array, rng=rng)
633-
g_fn = compile_random_function([], g, mode=jax_mode)
633+
g_fn = compile_random_function([], g)
634634
permuted = g_fn()
635635
with pytest.raises(AssertionError):
636636
np.testing.assert_allclose(array, permuted)
@@ -653,7 +653,7 @@ def test_random_geometric():
653653
rng = shared(np.random.default_rng(123))
654654
p = np.array([0.3, 0.7])
655655
g = pt.random.geometric(p, size=(10_000, 2), rng=rng)
656-
g_fn = compile_random_function([], g, mode=jax_mode)
656+
g_fn = compile_random_function([], g)
657657
samples = g_fn()
658658
np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
659659
np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
@@ -664,7 +664,7 @@ def test_negative_binomial():
664664
n = np.array([10, 40])
665665
p = np.array([0.3, 0.7])
666666
g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
667-
g_fn = compile_random_function([], g, mode=jax_mode)
667+
g_fn = compile_random_function([], g)
668668
samples = g_fn()
669669
np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
670670
np.testing.assert_allclose(
@@ -678,7 +678,7 @@ def test_binomial():
678678
n = np.array([10, 40])
679679
p = np.array([0.3, 0.7])
680680
g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng)
681-
g_fn = compile_random_function([], g, mode=jax_mode)
681+
g_fn = compile_random_function([], g)
682682
samples = g_fn()
683683
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
684684
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
@@ -693,7 +693,7 @@ def test_beta_binomial():
693693
a = np.array([1.5, 13])
694694
b = np.array([0.5, 9])
695695
g = pt.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
696-
g_fn = compile_random_function([], g, mode=jax_mode)
696+
g_fn = compile_random_function([], g)
697697
samples = g_fn()
698698
np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
699699
np.testing.assert_allclose(
@@ -754,7 +754,7 @@ def test_vonmises_mu_outside_circle():
754754
mu = np.array([-30, 40])
755755
kappa = np.array([100, 10])
756756
g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
757-
g_fn = compile_random_function([], g, mode=jax_mode)
757+
g_fn = compile_random_function([], g)
758758
samples = g_fn()
759759
np.testing.assert_allclose(
760760
samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
@@ -850,15 +850,15 @@ def test_random_concrete_shape():
850850
rng = shared(np.random.default_rng(123))
851851
x_pt = pt.dmatrix()
852852
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
853-
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
853+
jax_fn = compile_random_function([x_pt], out)
854854
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
855855

856856

857857
def test_random_concrete_shape_from_param():
858858
rng = shared(np.random.default_rng(123))
859859
x_pt = pt.dmatrix()
860860
out = pt.random.normal(x_pt, 1, rng=rng)
861-
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
861+
jax_fn = compile_random_function([x_pt], out)
862862
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
863863

864864

@@ -877,7 +877,7 @@ def test_random_concrete_shape_subtensor():
877877
rng = shared(np.random.default_rng(123))
878878
x_pt = pt.dmatrix()
879879
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
880-
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
880+
jax_fn = compile_random_function([x_pt], out)
881881
assert jax_fn(np.ones((2, 3))).shape == (3,)
882882

883883

@@ -893,7 +893,7 @@ def test_random_concrete_shape_subtensor_tuple():
893893
rng = shared(np.random.default_rng(123))
894894
x_pt = pt.dmatrix()
895895
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
896-
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
896+
jax_fn = compile_random_function([x_pt], out)
897897
assert jax_fn(np.ones((2, 3))).shape == (2,)
898898

899899

@@ -904,7 +904,7 @@ def test_random_concrete_shape_graph_input():
904904
rng = shared(np.random.default_rng(123))
905905
size_pt = pt.scalar()
906906
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
907-
jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
907+
jax_fn = compile_random_function([size_pt], out)
908908
assert jax_fn(10).shape == (10,)
909909

910910

0 commit comments

Comments
 (0)