@@ -805,94 +805,90 @@ def sample_fn(rng, size, dtype, *parameters):
805
805
compare_jax_and_py (fgraph , [])
806
806
807
807
808
- def test_random_concrete_shape ():
809
- """JAX should compile when a `RandomVariable` is passed a concrete shape.
810
-
811
- There are three quantities that JAX considers as concrete:
812
- 1. Constants known at compile time;
813
- 2. The shape of an array.
814
- 3. `static_argnums` parameters
815
- This test makes sure that graphs with `RandomVariable`s compile when the
816
- `size` parameter satisfies either of these criteria.
817
-
818
- """
819
- rng = shared (np .random .default_rng (123 ))
820
- x_pt = pt .dmatrix ()
821
- out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
822
- jax_fn = compile_random_function ([x_pt ], out )
823
- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
824
-
825
-
826
- def test_random_concrete_shape_from_param ():
827
- rng = shared (np .random .default_rng (123 ))
828
- x_pt = pt .dmatrix ()
829
- out = pt .random .normal (x_pt , 1 , rng = rng )
830
- jax_fn = compile_random_function ([x_pt ], out )
831
- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
832
-
833
-
834
- def test_random_concrete_shape_subtensor ():
835
- """JAX should compile when a concrete value is passed for the `size` parameter.
836
-
837
- This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
838
- inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
839
- inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
840
- rewrite.
841
-
842
- JAX does not accept scalars as `size` or `shape` arguments, so this is a
843
- slight improvement over their API.
844
-
845
- """
846
- rng = shared (np .random .default_rng (123 ))
847
- x_pt = pt .dmatrix ()
848
- out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
849
- jax_fn = compile_random_function ([x_pt ], out )
850
- assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
851
-
852
-
853
- def test_random_concrete_shape_subtensor_tuple ():
854
- """JAX should compile when a tuple of concrete values is passed for the `size` parameter.
855
-
856
- This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
857
- inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
858
- scalar inputs into tuples of concrete values using the
859
- `jax_size_parameter_as_tuple` rewrite.
860
-
861
- """
862
- rng = shared (np .random .default_rng (123 ))
863
- x_pt = pt .dmatrix ()
864
- out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
865
- jax_fn = compile_random_function ([x_pt ], out )
866
- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
867
-
868
-
869
- @pytest .mark .xfail (
870
- reason = "`size_pt` should be specified as a static argument" , strict = True
871
- )
872
- def test_random_concrete_shape_graph_input ():
873
- rng = shared (np .random .default_rng (123 ))
874
- size_pt = pt .scalar ()
875
- out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
876
- jax_fn = compile_random_function ([size_pt ], out )
877
- assert jax_fn (10 ).shape == (10 ,)
878
-
879
-
880
- def test_constant_shape_after_graph_rewriting ():
881
- size = pt .vector ("size" , shape = (2 ,), dtype = int )
882
- x = pt .random .normal (size = size )
883
- assert x .type .shape == (None , None )
884
-
885
- with pytest .raises (TypeError ):
886
- compile_random_function ([size ], x )([2 , 5 ])
887
-
888
- # Rebuild with strict=False so output type is not updated
889
- # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
890
- new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = True )
891
- assert new_x .type .shape == (None , None )
892
- assert compile_random_function ([], new_x )().shape == (2 , 5 )
893
-
894
- # Rebuild with strict=True, so output type is updated
895
- # This uses a different path in the dispatch implementation
896
- new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = False )
897
- assert new_x .type .shape == (2 , 5 )
898
- assert compile_random_function ([], new_x )().shape == (2 , 5 )
808
+ class TestRandomShapeInputs :
809
+ def test_random_concrete_shape (self ):
810
+ """JAX should compile when a `RandomVariable` is passed a concrete shape.
811
+
812
+ There are three quantities that JAX considers as concrete:
813
+ 1. Constants known at compile time;
814
+ 2. The shape of an array.
815
+ 3. `static_argnums` parameters
816
+ This test makes sure that graphs with `RandomVariable`s compile when the
817
+ `size` parameter satisfies either of these criteria.
818
+
819
+ """
820
+ rng = shared (np .random .default_rng (123 ))
821
+ x_pt = pt .dmatrix ()
822
+ out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
823
+ jax_fn = compile_random_function ([x_pt ], out )
824
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
825
+
826
+ def test_random_concrete_shape_from_param (self ):
827
+ rng = shared (np .random .default_rng (123 ))
828
+ x_pt = pt .dmatrix ()
829
+ out = pt .random .normal (x_pt , 1 , rng = rng )
830
+ jax_fn = compile_random_function ([x_pt ], out )
831
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
832
+
833
+ def test_random_concrete_shape_subtensor (self ):
834
+ """JAX should compile when a concrete value is passed for the `size` parameter.
835
+
836
+ This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
837
+ inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
838
+ inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
839
+ rewrite.
840
+
841
+ JAX does not accept scalars as `size` or `shape` arguments, so this is a
842
+ slight improvement over their API.
843
+
844
+ """
845
+ rng = shared (np .random .default_rng (123 ))
846
+ x_pt = pt .dmatrix ()
847
+ out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
848
+ jax_fn = compile_random_function ([x_pt ], out )
849
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
850
+
851
+ def test_random_concrete_shape_subtensor_tuple (self ):
852
+ """JAX should compile when a tuple of concrete values is passed for the `size` parameter.
853
+
854
+ This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
855
+ inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
856
+ scalar inputs into tuples of concrete values using the
857
+ `jax_size_parameter_as_tuple` rewrite.
858
+
859
+ """
860
+ rng = shared (np .random .default_rng (123 ))
861
+ x_pt = pt .dmatrix ()
862
+ out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
863
+ jax_fn = compile_random_function ([x_pt ], out )
864
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
865
+
866
+ @pytest .mark .xfail (
867
+ reason = "`size_pt` should be specified as a static argument" , strict = True
868
+ )
869
+ def test_random_concrete_shape_graph_input (self ):
870
+ rng = shared (np .random .default_rng (123 ))
871
+ size_pt = pt .scalar ()
872
+ out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
873
+ jax_fn = compile_random_function ([size_pt ], out )
874
+ assert jax_fn (10 ).shape == (10 ,)
875
+
876
+ def test_constant_shape_after_graph_rewriting (self ):
877
+ size = pt .vector ("size" , shape = (2 ,), dtype = int )
878
+ x = pt .random .normal (size = size )
879
+ assert x .type .shape == (None , None )
880
+
881
+ with pytest .raises (TypeError ):
882
+ compile_random_function ([size ], x )([2 , 5 ])
883
+
884
+ # Rebuild with strict=False so output type is not updated
885
+ # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
886
+ new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = True )
887
+ assert new_x .type .shape == (None , None )
888
+ assert compile_random_function ([], new_x )().shape == (2 , 5 )
889
+
890
+ # Rebuild with strict=True, so output type is updated
891
+ # This uses a different path in the dispatch implementation
892
+ new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = False )
893
+ assert new_x .type .shape == (2 , 5 )
894
+ assert compile_random_function ([], new_x )().shape == (2 , 5 )
0 commit comments