@@ -596,91 +596,88 @@ def test_rvs_to_value_vars_nested():
596
596
assert equal_computations (before , after )
597
597
598
598
599
- def test_check_bounds_flag ():
600
- """Test that CheckParameterValue Ops are replaced or removed when using compile_pymc"""
601
- logp = at .ones (3 )
602
- cond = np .array ([1 , 0 , 1 ])
603
- bound = check_parameters (logp , cond )
604
-
605
- with pm .Model () as m :
606
- pass
607
-
608
- with pytest .raises (ParameterValueError ):
609
- aesara .function ([], bound )()
610
-
611
- m .check_bounds = False
612
- with m :
613
- assert np .all (compile_pymc ([], bound )() == 1 )
614
-
615
- m .check_bounds = True
616
- with m :
617
- assert np .all (compile_pymc ([], bound )() == - np .inf )
618
-
619
-
620
- def test_compile_pymc_sets_rng_updates ():
621
- rng = aesara .shared (np .random .default_rng (0 ))
622
- x = pm .Normal .dist (rng = rng )
623
- assert x .owner .inputs [0 ] is rng
624
- f = compile_pymc ([], x )
625
- assert not np .isclose (f (), f ())
626
-
627
- # Check that update was not done inplace
628
- assert not hasattr (rng , "default_update" )
629
- f = aesara .function ([], x )
630
- assert f () == f ()
631
-
632
-
633
- def test_compile_pymc_with_updates ():
634
- x = aesara .shared (0 )
635
- f = compile_pymc ([], x , updates = {x : x + 1 })
636
- assert f () == 0
637
- assert f () == 1
638
-
639
-
640
- def test_compile_pymc_missing_default_explicit_updates ():
641
- rng = aesara .shared (np .random .default_rng (0 ))
642
- x = pm .Normal .dist (rng = rng )
643
-
644
- # By default, compile_pymc should update the rng of x
645
- f = compile_pymc ([], x )
646
- assert f () != f ()
647
-
648
- # An explicit update should override the default_update, like aesara.function does
649
- # For testing purposes, we use an update that leaves the rng unchanged
650
- f = compile_pymc ([], x , updates = {rng : rng })
651
- assert f () == f ()
652
-
653
- # If we specify a custom default_update directly it should use that instead.
654
- rng .default_update = rng
655
- f = compile_pymc ([], x )
656
- assert f () == f ()
657
-
658
- # And again, it should be overridden by an explicit update
659
- f = compile_pymc ([], x , updates = {rng : x .owner .outputs [0 ]})
660
- assert f () != f ()
661
-
662
-
663
- def test_compile_pymc_updates_inputs ():
664
- """Test that compile_pymc does not include rngs updates of variables that are inputs
665
- or ancestors to inputs
666
- """
667
- x = at .random .normal ()
668
- y = at .random .normal (x )
669
- z = at .random .normal (y )
670
-
671
- for inputs , rvs_in_graph in (
672
- ([], 3 ),
673
- ([x ], 2 ),
674
- ([y ], 1 ),
675
- ([z ], 0 ),
676
- ([x , y ], 1 ),
677
- ([x , y , z ], 0 ),
678
- ):
679
- fn = compile_pymc (inputs , z , on_unused_input = "ignore" )
680
- fn_fgraph = fn .maker .fgraph
681
- # Each RV adds a shared input for its rng
682
- assert len (fn_fgraph .inputs ) == len (inputs ) + rvs_in_graph
683
- # If the output is an input, the graph has a DeepCopyOp
684
- assert len (fn_fgraph .apply_nodes ) == max (rvs_in_graph , 1 )
685
- # Each RV adds a shared output for its rng
686
- assert len (fn_fgraph .outputs ) == 1 + rvs_in_graph
599
+ class TestCompilePyMC :
600
+ def test_check_bounds_flag (self ):
601
+ """Test that CheckParameterValue Ops are replaced or removed when using compile_pymc"""
602
+ logp = at .ones (3 )
603
+ cond = np .array ([1 , 0 , 1 ])
604
+ bound = check_parameters (logp , cond )
605
+
606
+ with pm .Model () as m :
607
+ pass
608
+
609
+ with pytest .raises (ParameterValueError ):
610
+ aesara .function ([], bound )()
611
+
612
+ m .check_bounds = False
613
+ with m :
614
+ assert np .all (compile_pymc ([], bound )() == 1 )
615
+
616
+ m .check_bounds = True
617
+ with m :
618
+ assert np .all (compile_pymc ([], bound )() == - np .inf )
619
+
620
+ def test_compile_pymc_sets_rng_updates (self ):
621
+ rng = aesara .shared (np .random .default_rng (0 ))
622
+ x = pm .Normal .dist (rng = rng )
623
+ assert x .owner .inputs [0 ] is rng
624
+ f = compile_pymc ([], x )
625
+ assert not np .isclose (f (), f ())
626
+
627
+ # Check that update was not done inplace
628
+ assert not hasattr (rng , "default_update" )
629
+ f = aesara .function ([], x )
630
+ assert f () == f ()
631
+
632
+ def test_compile_pymc_with_updates (self ):
633
+ x = aesara .shared (0 )
634
+ f = compile_pymc ([], x , updates = {x : x + 1 })
635
+ assert f () == 0
636
+ assert f () == 1
637
+
638
+ def test_compile_pymc_missing_default_explicit_updates (self ):
639
+ rng = aesara .shared (np .random .default_rng (0 ))
640
+ x = pm .Normal .dist (rng = rng )
641
+
642
+ # By default, compile_pymc should update the rng of x
643
+ f = compile_pymc ([], x )
644
+ assert f () != f ()
645
+
646
+ # An explicit update should override the default_update, like aesara.function does
647
+ # For testing purposes, we use an update that leaves the rng unchanged
648
+ f = compile_pymc ([], x , updates = {rng : rng })
649
+ assert f () == f ()
650
+
651
+ # If we specify a custom default_update directly it should use that instead.
652
+ rng .default_update = rng
653
+ f = compile_pymc ([], x )
654
+ assert f () == f ()
655
+
656
+ # And again, it should be overridden by an explicit update
657
+ f = compile_pymc ([], x , updates = {rng : x .owner .outputs [0 ]})
658
+ assert f () != f ()
659
+
660
+ def test_compile_pymc_updates_inputs (self ):
661
+ """Test that compile_pymc does not include rngs updates of variables that are inputs
662
+ or ancestors to inputs
663
+ """
664
+ x = at .random .normal ()
665
+ y = at .random .normal (x )
666
+ z = at .random .normal (y )
667
+
668
+ for inputs , rvs_in_graph in (
669
+ ([], 3 ),
670
+ ([x ], 2 ),
671
+ ([y ], 1 ),
672
+ ([z ], 0 ),
673
+ ([x , y ], 1 ),
674
+ ([x , y , z ], 0 ),
675
+ ):
676
+ fn = compile_pymc (inputs , z , on_unused_input = "ignore" )
677
+ fn_fgraph = fn .maker .fgraph
678
+ # Each RV adds a shared input for its rng
679
+ assert len (fn_fgraph .inputs ) == len (inputs ) + rvs_in_graph
680
+ # If the output is an input, the graph has a DeepCopyOp
681
+ assert len (fn_fgraph .apply_nodes ) == max (rvs_in_graph , 1 )
682
+ # Each RV adds a shared output for its rng
683
+ assert len (fn_fgraph .outputs ) == 1 + rvs_in_graph
0 commit comments