@@ -781,7 +781,7 @@ def test_det_kronecker_rewrite():
781
781
a , b = pt .dmatrices ("a" , "b" )
782
782
kron_prod = pt .linalg .kron (a , b )
783
783
det_output = pt .linalg .det (kron_prod )
784
- f_rewritten = function ([kron_prod ], [det_output ], mode = "FAST_RUN" )
784
+ f_rewritten = function ([a , b ], [det_output ], mode = "FAST_RUN" )
785
785
786
786
# Rewrite Test
787
787
nodes = f_rewritten .maker .fgraph .apply_nodes
@@ -791,7 +791,7 @@ def test_det_kronecker_rewrite():
791
791
a_test , b_test = np .random .rand (2 , 20 , 20 )
792
792
kron_prod_test = np .kron (a_test , b_test )
793
793
det_output_test = np .linalg .det (kron_prod_test )
794
- rewritten_det_val = f_rewritten (kron_prod_test )
794
+ rewritten_det_val = f_rewritten (a_test , b_test )
795
795
assert_allclose (
796
796
det_output_test ,
797
797
rewritten_det_val ,
@@ -800,6 +800,35 @@ def test_det_kronecker_rewrite():
800
800
)
801
801
802
802
803
+ def test_slogdet_kronecker_rewrite ():
804
+ a , b = pt .dmatrices ("a" , "b" )
805
+ kron_prod = pt .linalg .kron (a , b )
806
+ sign_output , logdet_output = pt .linalg .slogdet (kron_prod )
807
+ f_rewritten = function ([a , b ], [sign_output , logdet_output ], mode = "FAST_RUN" )
808
+
809
+ # Rewrite Test
810
+ nodes = f_rewritten .maker .fgraph .apply_nodes
811
+ assert not any (isinstance (node .op , KroneckerProduct ) for node in nodes )
812
+
813
+ # Value Test
814
+ a_test , b_test = np .random .rand (2 , 20 , 20 )
815
+ kron_prod_test = np .kron (a_test , b_test )
816
+ sign_output_test , logdet_output_test = np .linalg .slogdet (kron_prod_test )
817
+ rewritten_sign_val , rewritten_logdet_val = f_rewritten (a_test , b_test )
818
+ assert_allclose (
819
+ sign_output_test ,
820
+ rewritten_sign_val ,
821
+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
822
+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
823
+ )
824
+ assert_allclose (
825
+ logdet_output_test ,
826
+ rewritten_logdet_val ,
827
+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
828
+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
829
+ )
830
+
831
+
803
832
def test_cholesky_eye_rewrite ():
804
833
x = pt .eye (10 )
805
834
L = pt .linalg .cholesky (x )
@@ -904,20 +933,60 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
904
933
905
934
906
935
def test_slogdet_specialisation ():
907
- x = pt .dmatrix ("x" )
908
- det_x = pt .linalg .det (x )
909
- log_abs_det_x = pt .log (pt .abs (det_x ))
910
- sign_det_x = pt .sign (det_x )
936
+ x , a = pt .dmatrix ("x" ), np .random .rand (20 , 20 )
937
+ det_x , det_a = pt .linalg .det (x ), np .linalg .det (a )
938
+ log_abs_det_x , log_abs_det_a = pt .log (pt .abs (det_x )), np .log (np .abs (det_a ))
939
+ log_det_x , log_det_a = pt .log (det_x ), np .log (det_a )
940
+ sign_det_x , sign_det_a = pt .sign (det_x ), np .sign (det_a )
911
941
exp_det_x = pt .exp (det_x )
942
+ # REWRITE TESTS
912
943
# sign(det(x))
913
944
f = function ([x ], [sign_det_x ], mode = "FAST_RUN" )
914
945
nodes = f .maker .fgraph .apply_nodes
915
- assert any (isinstance (node .op , SLogDet ) for node in nodes )
946
+ assert len ([node for node in nodes if isinstance (node .op , SLogDet )]) == 1
947
+ assert not any (isinstance (node .op , Det ) for node in nodes )
948
+ rw_sign_det_a = f (a )
949
+ assert_allclose (
950
+ sign_det_a ,
951
+ rw_sign_det_a ,
952
+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
953
+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
954
+ )
916
955
# log(abs(det(x)))
917
956
f = function ([x ], [log_abs_det_x ], mode = "FAST_RUN" )
918
957
nodes = f .maker .fgraph .apply_nodes
919
- assert any (isinstance (node .op , SLogDet ) for node in nodes )
958
+ assert len ([node for node in nodes if isinstance (node .op , SLogDet )]) == 1
959
+ assert not any (isinstance (node .op , Det ) for node in nodes )
960
+ rw_log_abs_det_a = f (a )
961
+ assert_allclose (
962
+ log_abs_det_a ,
963
+ rw_log_abs_det_a ,
964
+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
965
+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
966
+ )
967
+ # log(det(x))
968
+ f = function ([x ], [log_det_x ], mode = "FAST_RUN" )
969
+ nodes = f .maker .fgraph .apply_nodes
970
+ assert len ([node for node in nodes if isinstance (node .op , SLogDet )]) == 1
971
+ assert not any (isinstance (node .op , Det ) for node in nodes )
972
+ rw_log_det_a = f (a )
973
+ assert_allclose (
974
+ log_det_a ,
975
+ rw_log_det_a ,
976
+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
977
+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
978
+ )
979
+ # more than 1 valid function
980
+ f = function ([x ], [sign_det_x , log_abs_det_x ], mode = "FAST_RUN" )
981
+ nodes = f .maker .fgraph .apply_nodes
982
+ assert len ([node for node in nodes if isinstance (node .op , SLogDet )]) == 1
983
+ assert not any (isinstance (node .op , Det ) for node in nodes )
920
984
# other functions (rewrite shouldnt be applied to these)
985
+ # only invalid functions
921
986
f = function ([x ], [exp_det_x ], mode = "FAST_RUN" )
922
987
nodes = f .maker .fgraph .apply_nodes
923
988
assert not any (isinstance (node .op , SLogDet ) for node in nodes )
989
+ # invalid + valid function
990
+ f = function ([x ], [exp_det_x , sign_det_x ], mode = "FAST_RUN" )
991
+ nodes = f .maker .fgraph .apply_nodes
992
+ assert not any (isinstance (node .op , SLogDet ) for node in nodes )
0 commit comments