@@ -894,15 +894,55 @@ def test_random_concrete_shape_subtensor_tuple(self):
894
894
jax_fn = compile_random_function ([x_pt ], out )
895
895
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
896
896
897
+ def test_random_scalar_shape_input (self ):
898
+ dim0 = pt .scalar ("dim0" , dtype = int )
899
+ dim1 = pt .scalar ("dim1" , dtype = int )
900
+
901
+ out = pt .random .normal (0 , 1 , size = dim0 )
902
+ jax_fn = compile_random_function ([dim0 ], out )
903
+ assert jax_fn (np .array (2 )).shape == (2 ,)
904
+ assert jax_fn (np .array (3 )).shape == (3 ,)
905
+
906
+ out = pt .random .normal (0 , 1 , size = [dim0 , dim1 ])
907
+ jax_fn = compile_random_function ([dim0 , dim1 ], out )
908
+ assert jax_fn (np .array (2 ), np .array (3 )).shape == (2 , 3 )
909
+ assert jax_fn (np .array (4 ), np .array (5 )).shape == (4 , 5 )
910
+
897
911
@pytest .mark .xfail (
898
- reason = "`size_pt` should be specified as a static argument" , strict = True
912
+ raises = TypeError , reason = "Cannot convert scalar input to integer"
899
913
)
900
- def test_random_concrete_shape_graph_input (self ):
901
- rng = shared (np .random .default_rng (123 ))
902
- size_pt = pt .scalar ()
903
- out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
904
- jax_fn = compile_random_function ([size_pt ], out )
905
- assert jax_fn (10 ).shape == (10 ,)
914
+ def test_random_scalar_shape_input_not_supported (self ):
915
+ dim = pt .scalar ("dim" , dtype = int )
916
+ out1 = pt .random .normal (0 , 1 , size = dim )
917
+ # An operation that wouldn't work if we replaced 0d array by integer
918
+ out2 = dim [...].set (1 )
919
+ jax_fn = compile_random_function ([dim ], [out1 , out2 ])
920
+
921
+ res1 , res2 = jax_fn (np .array (2 ))
922
+ assert res1 .shape == (2 ,)
923
+ assert res2 == 1
924
+
925
+ @pytest .mark .xfail (
926
+ raises = TypeError , reason = "Cannot convert scalar input to integer"
927
+ )
928
+ def test_random_scalar_shape_input_not_supported2 (self ):
929
+ dim = pt .scalar ("dim" , dtype = int )
930
+ # This could theoretically be supported
931
+ # but would require knowing that * 2 is a safe operation for a python integer
932
+ out = pt .random .normal (0 , 1 , size = dim * 2 )
933
+ jax_fn = compile_random_function ([dim ], out )
934
+ assert jax_fn (np .array (2 )).shape == (4 ,)
935
+
936
+ @pytest .mark .xfail (
937
+ raises = TypeError , reason = "Cannot convert tensor input to shape tuple"
938
+ )
939
+ def test_random_vector_shape_graph_input (self ):
940
+ shape = pt .vector ("shape" , shape = (2 ,), dtype = int )
941
+ out = pt .random .normal (0 , 1 , size = shape )
942
+
943
+ jax_fn = compile_random_function ([shape ], out )
944
+ assert jax_fn (np .array ([2 , 3 ])).shape == (2 , 3 )
945
+ assert jax_fn (np .array ([4 , 5 ])).shape == (4 , 5 )
906
946
907
947
def test_constant_shape_after_graph_rewriting (self ):
908
948
size = pt .vector ("size" , shape = (2 ,), dtype = int )
@@ -912,13 +952,13 @@ def test_constant_shape_after_graph_rewriting(self):
912
952
with pytest .raises (TypeError ):
913
953
compile_random_function ([size ], x )([2 , 5 ])
914
954
915
- # Rebuild with strict=False so output type is not updated
955
+ # Rebuild with strict=True so output type is not updated
916
956
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
917
957
new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = True )
918
958
assert new_x .type .shape == (None , None )
919
959
assert compile_random_function ([], new_x )().shape == (2 , 5 )
920
960
921
- # Rebuild with strict=True , so output type is updated
961
+ # Rebuild with strict=False , so output type is updated
922
962
# This uses a different path in the dispatch implementation
923
963
new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = False )
924
964
assert new_x .type .shape == (2 , 5 )
0 commit comments