|
40 | 40 | import pytest
|
41 | 41 | import scipy.stats.distributions as sp
|
42 | 42 |
|
| 43 | +from pytensor import function |
43 | 44 | from pytensor.graph.basic import Variable, equal_computations
|
| 45 | +from pytensor.ifelse import ifelse |
44 | 46 | from pytensor.tensor.random.basic import CategoricalRV
|
45 | 47 | from pytensor.tensor.shape import shape_tuple
|
46 | 48 | from pytensor.tensor.subtensor import as_index_constant
|
@@ -942,3 +944,109 @@ def test_switch_mixture():
|
942 | 944 |
|
943 | 945 | np.testing.assert_almost_equal(0.69049938, z1_logp.eval({z_vv: -10, i_vv: 0}))
|
944 | 946 | np.testing.assert_almost_equal(0.69049938, z2_logp.eval({z_vv: -10, i_vv: 0}))
|
| 947 | + |
| 948 | + |
| 949 | +def test_ifelse_mixture_one_component(): |
| 950 | + if_rv = pt.random.bernoulli(0.5, name="if") |
| 951 | + scale_rv = pt.random.halfnormal(name="scale") |
| 952 | + comp_then = pt.random.normal(0, scale_rv, size=(2,), name="comp_then") |
| 953 | + comp_else = pt.random.halfnormal(0, scale_rv, size=(4,), name="comp_else") |
| 954 | + mix_rv = ifelse(if_rv, comp_then, comp_else, name="mix") |
| 955 | + |
| 956 | + if_vv = if_rv.clone() |
| 957 | + scale_vv = scale_rv.clone() |
| 958 | + mix_vv = mix_rv.clone() |
| 959 | + mix_logp = factorized_joint_logprob({if_rv: if_vv, scale_rv: scale_vv, mix_rv: mix_vv})[mix_vv] |
| 960 | + assert_no_rvs(mix_logp) |
| 961 | + |
| 962 | + fn = function([if_vv, scale_vv, mix_vv], mix_logp) |
| 963 | + scale_vv_test = 0.75 |
| 964 | + mix_vv_test = np.r_[1.0, 2.5] |
| 965 | + np.testing.assert_array_almost_equal( |
| 966 | + fn(1, scale_vv_test, mix_vv_test), |
| 967 | + sp.norm(0, scale_vv_test).logpdf(mix_vv_test), |
| 968 | + ) |
| 969 | + mix_vv_test = np.r_[1.0, 2.5, 3.5, 4.0] |
| 970 | + np.testing.assert_array_almost_equal( |
| 971 | + fn(0, scale_vv_test, mix_vv_test), sp.halfnorm(0, scale_vv_test).logpdf(mix_vv_test) |
| 972 | + ) |
| 973 | + |
| 974 | + |
| 975 | +def test_ifelse_mixture_multiple_components(): |
| 976 | + rng = np.random.default_rng(968) |
| 977 | + |
| 978 | + if_var = pt.scalar("if_var", dtype="bool") |
| 979 | + comp_then1 = pt.random.normal(size=(2,), name="comp_true1") |
| 980 | + comp_then2 = pt.random.normal(comp_then1, size=(2, 2), name="comp_then2") |
| 981 | + comp_else1 = pt.random.halfnormal(size=(4,), name="comp_else1") |
| 982 | + comp_else2 = pt.random.halfnormal(size=(4, 4), name="comp_else2") |
| 983 | + |
| 984 | + mix_rv1, mix_rv2 = ifelse( |
| 985 | + if_var, [comp_then1, comp_then2], [comp_else1, comp_else2], name="mix" |
| 986 | + ) |
| 987 | + mix_vv1 = mix_rv1.clone() |
| 988 | + mix_vv2 = mix_rv2.clone() |
| 989 | + mix_logp1, mix_logp2 = factorized_joint_logprob({mix_rv1: mix_vv1, mix_rv2: mix_vv2}).values() |
| 990 | + assert_no_rvs(mix_logp1) |
| 991 | + assert_no_rvs(mix_logp2) |
| 992 | + |
| 993 | + fn = function([if_var, mix_vv1, mix_vv2], mix_logp1.sum() + mix_logp2.sum()) |
| 994 | + mix_vv1_test = np.abs(rng.normal(size=(2,))) |
| 995 | + mix_vv2_test = np.abs(rng.normal(size=(2, 2))) |
| 996 | + np.testing.assert_almost_equal( |
| 997 | + fn(True, mix_vv1_test, mix_vv2_test), |
| 998 | + sp.norm(0, 1).logpdf(mix_vv1_test).sum() |
| 999 | + + sp.norm(mix_vv1_test, 1).logpdf(mix_vv2_test).sum(), |
| 1000 | + ) |
| 1001 | + mix_vv1_test = np.abs(rng.normal(size=(4,))) |
| 1002 | + mix_vv2_test = np.abs(rng.normal(size=(4, 4))) |
| 1003 | + np.testing.assert_almost_equal( |
| 1004 | + fn(False, mix_vv1_test, mix_vv2_test), |
| 1005 | + sp.halfnorm(0, 1).logpdf(mix_vv1_test).sum() + sp.halfnorm(0, 1).logpdf(mix_vv2_test).sum(), |
| 1006 | + ) |
| 1007 | + |
| 1008 | + |
| 1009 | +def test_ifelse_mixture_shared_component(): |
| 1010 | + rng = np.random.default_rng(1009) |
| 1011 | + |
| 1012 | + if_var = pt.scalar("if_var", dtype="bool") |
| 1013 | + outer_rv = pt.random.normal(name="outer") |
| 1014 | + # comp_shared need not be an output of ifelse at all, |
| 1015 | + # but since we allow arbitrary graphs we test it works as expected. |
| 1016 | + comp_shared = pt.random.normal(size=(2,), name="comp_shared") |
| 1017 | + comp_then = outer_rv + pt.random.normal(comp_shared, 1, size=(4, 2), name="comp_then") |
| 1018 | + comp_else = outer_rv + pt.random.normal(comp_shared, 10, size=(8, 2), name="comp_else") |
| 1019 | + shared_rv, mix_rv = ifelse( |
| 1020 | + if_var, [comp_shared, comp_then], [comp_shared, comp_else], name="mix" |
| 1021 | + ) |
| 1022 | + |
| 1023 | + outer_vv = outer_rv.clone() |
| 1024 | + shared_vv = shared_rv.clone() |
| 1025 | + mix_vv = mix_rv.clone() |
| 1026 | + outer_logp, mix_logp1, mix_logp2 = factorized_joint_logprob( |
| 1027 | + {outer_rv: outer_vv, shared_rv: shared_vv, mix_rv: mix_vv} |
| 1028 | + ).values() |
| 1029 | + assert_no_rvs(outer_logp) |
| 1030 | + assert_no_rvs(mix_logp1) |
| 1031 | + assert_no_rvs(mix_logp2) |
| 1032 | + |
| 1033 | + fn = function([if_var, outer_vv, shared_vv, mix_vv], mix_logp1.sum() + mix_logp2.sum()) |
| 1034 | + outer_vv_test = rng.normal() |
| 1035 | + shared_vv_test = rng.normal(size=(2,)) |
| 1036 | + mix_vv_test = rng.normal(size=(4, 2)) |
| 1037 | + np.testing.assert_almost_equal( |
| 1038 | + fn(True, outer_vv_test, shared_vv_test, mix_vv_test), |
| 1039 | + ( |
| 1040 | + sp.norm(0, 1).logpdf(shared_vv_test).sum() |
| 1041 | + + sp.norm(outer_vv_test + shared_vv_test, 1).logpdf(mix_vv_test).sum() |
| 1042 | + ), |
| 1043 | + ) |
| 1044 | + mix_vv_test = rng.normal(size=(8, 2)) |
| 1045 | + np.testing.assert_almost_equal( |
| 1046 | + fn(False, outer_vv_test, shared_vv_test, mix_vv_test), |
| 1047 | + ( |
| 1048 | + sp.norm(0, 1).logpdf(shared_vv_test).sum() |
| 1049 | + + sp.norm(outer_vv_test + shared_vv_test, 10).logpdf(mix_vv_test).sum() |
| 1050 | + ), |
| 1051 | + decimal=6, |
| 1052 | + ) |
0 commit comments