@@ -836,94 +836,90 @@ def sample_fn(rng, size, dtype, *parameters):
836
836
compare_jax_and_py ([], [out ], [])
837
837
838
838
839
- def test_random_concrete_shape ():
840
- """JAX should compile when a `RandomVariable` is passed a concrete shape.
841
-
842
- There are three quantities that JAX considers as concrete:
843
- 1. Constants known at compile time;
844
- 2. The shape of an array.
845
- 3. `static_argnums` parameters
846
- This test makes sure that graphs with `RandomVariable`s compile when the
847
- `size` parameter satisfies either of these criteria.
848
-
849
- """
850
- rng = shared (np .random .default_rng (123 ))
851
- x_pt = pt .dmatrix ()
852
- out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
853
- jax_fn = compile_random_function ([x_pt ], out )
854
- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
855
-
856
-
857
- def test_random_concrete_shape_from_param ():
858
- rng = shared (np .random .default_rng (123 ))
859
- x_pt = pt .dmatrix ()
860
- out = pt .random .normal (x_pt , 1 , rng = rng )
861
- jax_fn = compile_random_function ([x_pt ], out )
862
- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
863
-
864
-
865
- def test_random_concrete_shape_subtensor ():
866
- """JAX should compile when a concrete value is passed for the `size` parameter.
867
-
868
- This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
869
- inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
870
- inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
871
- rewrite.
872
-
873
- JAX does not accept scalars as `size` or `shape` arguments, so this is a
874
- slight improvement over their API.
875
-
876
- """
877
- rng = shared (np .random .default_rng (123 ))
878
- x_pt = pt .dmatrix ()
879
- out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
880
- jax_fn = compile_random_function ([x_pt ], out )
881
- assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
882
-
883
-
884
- def test_random_concrete_shape_subtensor_tuple ():
885
- """JAX should compile when a tuple of concrete values is passed for the `size` parameter.
886
-
887
- This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
888
- inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
889
- scalar inputs into tuples of concrete values using the
890
- `jax_size_parameter_as_tuple` rewrite.
891
-
892
- """
893
- rng = shared (np .random .default_rng (123 ))
894
- x_pt = pt .dmatrix ()
895
- out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
896
- jax_fn = compile_random_function ([x_pt ], out )
897
- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
898
-
899
-
900
- @pytest .mark .xfail (
901
- reason = "`size_pt` should be specified as a static argument" , strict = True
902
- )
903
- def test_random_concrete_shape_graph_input ():
904
- rng = shared (np .random .default_rng (123 ))
905
- size_pt = pt .scalar ()
906
- out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
907
- jax_fn = compile_random_function ([size_pt ], out )
908
- assert jax_fn (10 ).shape == (10 ,)
909
-
910
-
911
- def test_constant_shape_after_graph_rewriting ():
912
- size = pt .vector ("size" , shape = (2 ,), dtype = int )
913
- x = pt .random .normal (size = size )
914
- assert x .type .shape == (None , None )
915
-
916
- with pytest .raises (TypeError ):
917
- compile_random_function ([size ], x )([2 , 5 ])
918
-
919
- # Rebuild with strict=False so output type is not updated
920
- # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
921
- new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = True )
922
- assert new_x .type .shape == (None , None )
923
- assert compile_random_function ([], new_x )().shape == (2 , 5 )
924
-
925
- # Rebuild with strict=True, so output type is updated
926
- # This uses a different path in the dispatch implementation
927
- new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = False )
928
- assert new_x .type .shape == (2 , 5 )
929
- assert compile_random_function ([], new_x )().shape == (2 , 5 )
839
+ class TestRandomShapeInputs :
840
+ def test_random_concrete_shape (self ):
841
+ """JAX should compile when a `RandomVariable` is passed a concrete shape.
842
+
843
+ There are three quantities that JAX considers as concrete:
844
+ 1. Constants known at compile time;
845
+ 2. The shape of an array.
846
+ 3. `static_argnums` parameters
847
+ This test makes sure that graphs with `RandomVariable`s compile when the
848
+ `size` parameter satisfies either of these criteria.
849
+
850
+ """
851
+ rng = shared (np .random .default_rng (123 ))
852
+ x_pt = pt .dmatrix ()
853
+ out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
854
+ jax_fn = compile_random_function ([x_pt ], out )
855
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
856
+
857
+ def test_random_concrete_shape_from_param (self ):
858
+ rng = shared (np .random .default_rng (123 ))
859
+ x_pt = pt .dmatrix ()
860
+ out = pt .random .normal (x_pt , 1 , rng = rng )
861
+ jax_fn = compile_random_function ([x_pt ], out )
862
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
863
+
864
+ def test_random_concrete_shape_subtensor (self ):
865
+ """JAX should compile when a concrete value is passed for the `size` parameter.
866
+
867
+ This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
868
+ inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
869
+ inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
870
+ rewrite.
871
+
872
+ JAX does not accept scalars as `size` or `shape` arguments, so this is a
873
+ slight improvement over their API.
874
+
875
+ """
876
+ rng = shared (np .random .default_rng (123 ))
877
+ x_pt = pt .dmatrix ()
878
+ out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
879
+ jax_fn = compile_random_function ([x_pt ], out )
880
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
881
+
882
+ def test_random_concrete_shape_subtensor_tuple (self ):
883
+ """JAX should compile when a tuple of concrete values is passed for the `size` parameter.
884
+
885
+ This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
886
+ inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
887
+ scalar inputs into tuples of concrete values using the
888
+ `jax_size_parameter_as_tuple` rewrite.
889
+
890
+ """
891
+ rng = shared (np .random .default_rng (123 ))
892
+ x_pt = pt .dmatrix ()
893
+ out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
894
+ jax_fn = compile_random_function ([x_pt ], out )
895
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
896
+
897
+ @pytest .mark .xfail (
898
+ reason = "`size_pt` should be specified as a static argument" , strict = True
899
+ )
900
+ def test_random_concrete_shape_graph_input (self ):
901
+ rng = shared (np .random .default_rng (123 ))
902
+ size_pt = pt .scalar ()
903
+ out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
904
+ jax_fn = compile_random_function ([size_pt ], out )
905
+ assert jax_fn (10 ).shape == (10 ,)
906
+
907
+ def test_constant_shape_after_graph_rewriting (self ):
908
+ size = pt .vector ("size" , shape = (2 ,), dtype = int )
909
+ x = pt .random .normal (size = size )
910
+ assert x .type .shape == (None , None )
911
+
912
+ with pytest .raises (TypeError ):
913
+ compile_random_function ([size ], x )([2 , 5 ])
914
+
915
+ # Rebuild with strict=False so output type is not updated
916
+ # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
917
+ new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = True )
918
+ assert new_x .type .shape == (None , None )
919
+ assert compile_random_function ([], new_x )().shape == (2 , 5 )
920
+
921
+ # Rebuild with strict=True, so output type is updated
922
+ # This uses a different path in the dispatch implementation
923
+ new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = False )
924
+ assert new_x .type .shape == (2 , 5 )
925
+ assert compile_random_function ([], new_x )().shape == (2 , 5 )
0 commit comments