26
26
from pytensor .link .jax .dispatch .random import numpyro_available # noqa: E402
27
27
28
28
29
- def compile_random_function (* args , mode = "JAX" , ** kwargs ):
29
+ def compile_random_function (* args , mode = jax_mode , ** kwargs ):
30
30
with pytest .warns (
31
31
UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
32
32
):
@@ -41,7 +41,7 @@ def test_random_RandomStream():
41
41
srng = RandomStream (seed = 123 )
42
42
out = srng .normal () - srng .normal ()
43
43
44
- fn = compile_random_function ([], out , mode = jax_mode )
44
+ fn = compile_random_function ([], out )
45
45
jax_res_1 = fn ()
46
46
jax_res_2 = fn ()
47
47
@@ -54,7 +54,7 @@ def test_random_updates(rng_ctor):
54
54
rng = shared (original_value , name = "original_rng" , borrow = False )
55
55
next_rng , x = pt .random .normal (name = "x" , rng = rng ).owner .outputs
56
56
57
- f = compile_random_function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
57
+ f = compile_random_function ([], [x ], updates = {rng : next_rng })
58
58
assert f () != f ()
59
59
60
60
# Check that original rng variable content was not overwritten when calling jax_typify
@@ -482,7 +482,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
482
482
)
483
483
rng = shared (np .random .default_rng (29403 ))
484
484
g = rv_op (* dist_params , size = (10000 , * base_size ), rng = rng )
485
- g_fn = compile_random_function (dist_params , g , mode = jax_mode )
485
+ g_fn = compile_random_function (dist_params , g )
486
486
samples = g_fn (* test_values )
487
487
488
488
bcast_dist_args = np .broadcast_arrays (* test_values )
@@ -518,7 +518,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
518
518
param_that_implies_size = pt .matrix ("param_that_implies_size" , shape = (None , None ))
519
519
520
520
rv = rv_fn (param_that_implies_size )
521
- draws = rv .eval ({param_that_implies_size : np .zeros ((2 , 2 ))}, mode = jax_mode )
521
+ draws = rv .eval ({param_that_implies_size : np .zeros ((2 , 2 ))})
522
522
523
523
assert draws .shape == (2 , 2 )
524
524
assert np .unique (draws ).size == 4
@@ -528,7 +528,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
528
528
def test_random_bernoulli (size ):
529
529
rng = shared (np .random .default_rng (123 ))
530
530
g = pt .random .bernoulli (0.5 , size = (1000 , * size ), rng = rng )
531
- g_fn = compile_random_function ([], g , mode = jax_mode )
531
+ g_fn = compile_random_function ([], g )
532
532
samples = g_fn ()
533
533
np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
534
534
@@ -539,7 +539,7 @@ def test_random_mvnormal():
539
539
mu = np .ones (4 )
540
540
cov = np .eye (4 )
541
541
g = pt .random .multivariate_normal (mu , cov , size = (10000 ,), rng = rng )
542
- g_fn = compile_random_function ([], g , mode = jax_mode )
542
+ g_fn = compile_random_function ([], g )
543
543
samples = g_fn ()
544
544
np .testing .assert_allclose (samples .mean (axis = 0 ), mu , atol = 0.1 )
545
545
@@ -559,7 +559,7 @@ def test_random_mvnormal():
559
559
def test_random_dirichlet (parameter , size ):
560
560
rng = shared (np .random .default_rng (123 ))
561
561
g = pt .random .dirichlet (parameter , size = (1000 , * size ), rng = rng )
562
- g_fn = compile_random_function ([], g , mode = jax_mode )
562
+ g_fn = compile_random_function ([], g )
563
563
samples = g_fn ()
564
564
np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
565
565
@@ -568,7 +568,7 @@ def test_random_choice():
568
568
# `replace=True` and `p is None`
569
569
rng = shared (np .random .default_rng (123 ))
570
570
g = pt .random .choice (np .arange (4 ), size = 10_000 , rng = rng )
571
- g_fn = compile_random_function ([], g , mode = jax_mode )
571
+ g_fn = compile_random_function ([], g )
572
572
samples = g_fn ()
573
573
assert samples .shape == (10_000 ,)
574
574
# Elements are picked at equal frequency
@@ -577,7 +577,7 @@ def test_random_choice():
577
577
# `replace=True` and `p is not None`
578
578
rng = shared (np .random .default_rng (123 ))
579
579
g = pt .random .choice (4 , p = np .array ([0.0 , 0.5 , 0.0 , 0.5 ]), size = (5 , 2 ), rng = rng )
580
- g_fn = compile_random_function ([], g , mode = jax_mode )
580
+ g_fn = compile_random_function ([], g )
581
581
samples = g_fn ()
582
582
assert samples .shape == (5 , 2 )
583
583
# Only odd numbers are picked
@@ -586,7 +586,7 @@ def test_random_choice():
586
586
# `replace=False` and `p is None`
587
587
rng = shared (np .random .default_rng (123 ))
588
588
g = pt .random .choice (np .arange (100 ), replace = False , size = (2 , 49 ), rng = rng )
589
- g_fn = compile_random_function ([], g , mode = jax_mode )
589
+ g_fn = compile_random_function ([], g )
590
590
samples = g_fn ()
591
591
assert samples .shape == (2 , 49 )
592
592
# Elements are unique
@@ -601,7 +601,7 @@ def test_random_choice():
601
601
rng = rng ,
602
602
replace = False ,
603
603
)
604
- g_fn = compile_random_function ([], g , mode = jax_mode )
604
+ g_fn = compile_random_function ([], g )
605
605
samples = g_fn ()
606
606
assert samples .shape == (3 ,)
607
607
# Elements are unique
@@ -613,14 +613,14 @@ def test_random_choice():
613
613
def test_random_categorical ():
614
614
rng = shared (np .random .default_rng (123 ))
615
615
g = pt .random .categorical (0.25 * np .ones (4 ), size = (10000 , 4 ), rng = rng )
616
- g_fn = compile_random_function ([], g , mode = jax_mode )
616
+ g_fn = compile_random_function ([], g )
617
617
samples = g_fn ()
618
618
assert samples .shape == (10000 , 4 )
619
619
np .testing .assert_allclose (samples .mean (axis = 0 ), 6 / 4 , 1 )
620
620
621
621
# Test zero probabilities
622
622
g = pt .random .categorical ([0 , 0.5 , 0 , 0.5 ], size = (1000 ,), rng = rng )
623
- g_fn = compile_random_function ([], g , mode = jax_mode )
623
+ g_fn = compile_random_function ([], g )
624
624
samples = g_fn ()
625
625
assert samples .shape == (1000 ,)
626
626
assert np .all (samples % 2 == 1 )
@@ -630,7 +630,7 @@ def test_random_permutation():
630
630
array = np .arange (4 )
631
631
rng = shared (np .random .default_rng (123 ))
632
632
g = pt .random .permutation (array , rng = rng )
633
- g_fn = compile_random_function ([], g , mode = jax_mode )
633
+ g_fn = compile_random_function ([], g )
634
634
permuted = g_fn ()
635
635
with pytest .raises (AssertionError ):
636
636
np .testing .assert_allclose (array , permuted )
@@ -653,7 +653,7 @@ def test_random_geometric():
653
653
rng = shared (np .random .default_rng (123 ))
654
654
p = np .array ([0.3 , 0.7 ])
655
655
g = pt .random .geometric (p , size = (10_000 , 2 ), rng = rng )
656
- g_fn = compile_random_function ([], g , mode = jax_mode )
656
+ g_fn = compile_random_function ([], g )
657
657
samples = g_fn ()
658
658
np .testing .assert_allclose (samples .mean (axis = 0 ), 1 / p , rtol = 0.1 )
659
659
np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt ((1 - p ) / p ** 2 ), rtol = 0.1 )
@@ -664,7 +664,7 @@ def test_negative_binomial():
664
664
n = np .array ([10 , 40 ])
665
665
p = np .array ([0.3 , 0.7 ])
666
666
g = pt .random .negative_binomial (n , p , size = (10_000 , 2 ), rng = rng )
667
- g_fn = compile_random_function ([], g , mode = jax_mode )
667
+ g_fn = compile_random_function ([], g )
668
668
samples = g_fn ()
669
669
np .testing .assert_allclose (samples .mean (axis = 0 ), n * (1 - p ) / p , rtol = 0.1 )
670
670
np .testing .assert_allclose (
@@ -678,7 +678,7 @@ def test_binomial():
678
678
n = np .array ([10 , 40 ])
679
679
p = np .array ([0.3 , 0.7 ])
680
680
g = pt .random .binomial (n , p , size = (10_000 , 2 ), rng = rng )
681
- g_fn = compile_random_function ([], g , mode = jax_mode )
681
+ g_fn = compile_random_function ([], g )
682
682
samples = g_fn ()
683
683
np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
684
684
np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.1 )
@@ -693,7 +693,7 @@ def test_beta_binomial():
693
693
a = np .array ([1.5 , 13 ])
694
694
b = np .array ([0.5 , 9 ])
695
695
g = pt .random .betabinom (n , a , b , size = (10_000 , 2 ), rng = rng )
696
- g_fn = compile_random_function ([], g , mode = jax_mode )
696
+ g_fn = compile_random_function ([], g )
697
697
samples = g_fn ()
698
698
np .testing .assert_allclose (samples .mean (axis = 0 ), n * a / (a + b ), rtol = 0.1 )
699
699
np .testing .assert_allclose (
@@ -754,7 +754,7 @@ def test_vonmises_mu_outside_circle():
754
754
mu = np .array ([- 30 , 40 ])
755
755
kappa = np .array ([100 , 10 ])
756
756
g = pt .random .vonmises (mu , kappa , size = (10_000 , 2 ), rng = rng )
757
- g_fn = compile_random_function ([], g , mode = jax_mode )
757
+ g_fn = compile_random_function ([], g )
758
758
samples = g_fn ()
759
759
np .testing .assert_allclose (
760
760
samples .mean (axis = 0 ), (mu + np .pi ) % (2.0 * np .pi ) - np .pi , rtol = 0.1
@@ -850,15 +850,15 @@ def test_random_concrete_shape():
850
850
rng = shared (np .random .default_rng (123 ))
851
851
x_pt = pt .dmatrix ()
852
852
out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
853
- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
853
+ jax_fn = compile_random_function ([x_pt ], out )
854
854
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
855
855
856
856
857
857
def test_random_concrete_shape_from_param ():
858
858
rng = shared (np .random .default_rng (123 ))
859
859
x_pt = pt .dmatrix ()
860
860
out = pt .random .normal (x_pt , 1 , rng = rng )
861
- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
861
+ jax_fn = compile_random_function ([x_pt ], out )
862
862
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
863
863
864
864
@@ -877,7 +877,7 @@ def test_random_concrete_shape_subtensor():
877
877
rng = shared (np .random .default_rng (123 ))
878
878
x_pt = pt .dmatrix ()
879
879
out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
880
- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
880
+ jax_fn = compile_random_function ([x_pt ], out )
881
881
assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
882
882
883
883
@@ -893,7 +893,7 @@ def test_random_concrete_shape_subtensor_tuple():
893
893
rng = shared (np .random .default_rng (123 ))
894
894
x_pt = pt .dmatrix ()
895
895
out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
896
- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
896
+ jax_fn = compile_random_function ([x_pt ], out )
897
897
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
898
898
899
899
@@ -904,7 +904,7 @@ def test_random_concrete_shape_graph_input():
904
904
rng = shared (np .random .default_rng (123 ))
905
905
size_pt = pt .scalar ()
906
906
out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
907
- jax_fn = compile_random_function ([size_pt ], out , mode = jax_mode )
907
+ jax_fn = compile_random_function ([size_pt ], out )
908
908
assert jax_fn (10 ).shape == (10 ,)
909
909
910
910
0 commit comments