@@ -937,38 +937,46 @@ def test_infer_static_shape():
937
937
assert static_shape == (1 ,)
938
938
939
939
940
- # This is slow for the ('int8', 3) version.
941
- def test_eye ():
942
- def check (dtype , N , M_ = None , k = 0 ):
943
- # PyTensor does not accept None as a tensor.
944
- # So we must use a real value.
945
- M = M_
946
- # Currently DebugMode does not support None as inputs even if this is
947
- # allowed.
948
- if M is None and config .mode in ["DebugMode" , "DEBUG_MODE" ]:
949
- M = N
950
- N_symb = iscalar ()
951
- M_symb = iscalar ()
952
- k_symb = iscalar ()
953
- f = function ([N_symb , M_symb , k_symb ], eye (N_symb , M_symb , k_symb , dtype = dtype ))
954
- result = f (N , M , k )
955
- assert np .allclose (result , np .eye (N , M_ , k , dtype = dtype ))
956
- assert result .dtype == np .dtype (dtype )
940
+ class TestEye :
941
+ # This is slow for the ('int8', 3) version.
942
+ def test_basic (self ):
943
+ def check (dtype , N , M_ = None , k = 0 ):
944
+ # PyTensor does not accept None as a tensor.
945
+ # So we must use a real value.
946
+ M = M_
947
+ # Currently DebugMode does not support None as inputs even if this is
948
+ # allowed.
949
+ if M is None and config .mode in ["DebugMode" , "DEBUG_MODE" ]:
950
+ M = N
951
+ N_symb = iscalar ()
952
+ M_symb = iscalar ()
953
+ k_symb = iscalar ()
954
+ f = function (
955
+ [N_symb , M_symb , k_symb ], eye (N_symb , M_symb , k_symb , dtype = dtype )
956
+ )
957
+ result = f (N , M , k )
958
+ assert np .allclose (result , np .eye (N , M_ , k , dtype = dtype ))
959
+ assert result .dtype == np .dtype (dtype )
957
960
958
- for dtype in ALL_DTYPES :
959
- check (dtype , 3 )
960
- # M != N, k = 0
961
- check (dtype , 3 , 5 )
962
- check (dtype , 5 , 3 )
963
- # N == M, k != 0
964
- check (dtype , 3 , 3 , 1 )
965
- check (dtype , 3 , 3 , - 1 )
966
- # N < M, k != 0
967
- check (dtype , 3 , 5 , 1 )
968
- check (dtype , 3 , 5 , - 1 )
969
- # N > M, k != 0
970
- check (dtype , 5 , 3 , 1 )
971
- check (dtype , 5 , 3 , - 1 )
961
+ for dtype in ALL_DTYPES :
962
+ check (dtype , 3 )
963
+ # M != N, k = 0
964
+ check (dtype , 3 , 5 )
965
+ check (dtype , 5 , 3 )
966
+ # N == M, k != 0
967
+ check (dtype , 3 , 3 , 1 )
968
+ check (dtype , 3 , 3 , - 1 )
969
+ # N < M, k != 0
970
+ check (dtype , 3 , 5 , 1 )
971
+ check (dtype , 3 , 5 , - 1 )
972
+ # N > M, k != 0
973
+ check (dtype , 5 , 3 , 1 )
974
+ check (dtype , 5 , 3 , - 1 )
975
+
976
+ def test_static_output_type (self ):
977
+ l = lscalar ("l" )
978
+ assert eye (5 , 3 , l ).type .shape == (5 , 3 )
979
+ assert eye (1 , l , 3 ).type .shape == (1 , None )
972
980
973
981
974
982
class TestTriangle :
0 commit comments