Skip to content

Commit 0334994

Browse files
committed
Derive logprob of IfElse graphs
1 parent 239da11 commit 0334994

File tree

2 files changed

+179
-2
lines changed

2 files changed

+179
-2
lines changed

pymc/logprob/mixture.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
node_rewriter,
4848
pre_greedy_node_rewriter,
4949
)
50-
from pytensor.ifelse import ifelse
50+
from pytensor.ifelse import IfElse, ifelse
5151
from pytensor.scalar.basic import Switch
5252
from pytensor.tensor.basic import Join, MakeVector
5353
from pytensor.tensor.elemwise import Elemwise
@@ -73,10 +73,11 @@
7373
from pymc.logprob.rewriting import (
7474
local_lift_DiracDelta,
7575
logprob_rewrites_db,
76+
measurable_ir_rewrites_db,
7677
subtensor_ops,
7778
)
7879
from pymc.logprob.tensor import naive_bcast_rv_lift
79-
from pymc.logprob.utils import ignore_logprob
80+
from pymc.logprob.utils import ignore_logprob, ignore_logprob_multiple_vars
8081

8182

8283
def is_newaxis(x):
@@ -483,3 +484,71 @@ def logprob_MixtureRV(
483484
"basic",
484485
"mixture",
485486
)
487+
488+
489+
class MeasurableIfElse(IfElse):
490+
"""Measurable subclass of IfElse operator."""
491+
492+
493+
MeasurableVariable.register(MeasurableIfElse)
494+
495+
496+
@node_rewriter([IfElse])
497+
def find_measurable_ifelse_mixture(fgraph, node):
498+
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
499+
500+
if rv_map_feature is None:
501+
return None # pragma: no cover
502+
503+
if isinstance(node.op, MeasurableIfElse):
504+
return None
505+
506+
# Check if all components are unvalued measuarable variables
507+
if_var, *base_rvs = node.inputs
508+
509+
if not all(
510+
(
511+
rv.owner is not None
512+
and isinstance(rv.owner.op, MeasurableVariable)
513+
and rv not in rv_map_feature.rv_values
514+
)
515+
for rv in base_rvs
516+
):
517+
return None # pragma: no cover
518+
519+
unmeasurable_base_rvs = ignore_logprob_multiple_vars(base_rvs, rv_map_feature.rv_values)
520+
521+
return MeasurableIfElse(n_outs=node.op.n_outs).make_node(if_var, *unmeasurable_base_rvs).outputs
522+
523+
524+
measurable_ir_rewrites_db.register(
525+
"find_measurable_ifelse_mixture",
526+
find_measurable_ifelse_mixture,
527+
"basic",
528+
"mixture",
529+
)
530+
531+
532+
@_logprob.register(MeasurableIfElse)
533+
def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs):
534+
"""Compute the log-likelihood graph for an `IfElse`."""
535+
from pymc.pytensorf import replace_rvs_by_values
536+
537+
assert len(values) * 2 == len(base_rvs)
538+
539+
rvs_to_values_then = {then_rv: value for then_rv, value in zip(base_rvs[: len(values)], values)}
540+
rvs_to_values_else = {else_rv: value for else_rv, value in zip(base_rvs[len(values) :], values)}
541+
542+
logps_then = [
543+
logprob(rv_then, value, **kwargs) for rv_then, value in rvs_to_values_then.items()
544+
]
545+
logps_else = [
546+
logprob(rv_else, value, **kwargs) for rv_else, value in rvs_to_values_else.items()
547+
]
548+
549+
# If the multiple variables depend on each other, we have to replace them
550+
# by the respective values
551+
logps_then = replace_rvs_by_values(logps_then, rvs_to_values=rvs_to_values_then)
552+
logps_else = replace_rvs_by_values(logps_else, rvs_to_values=rvs_to_values_else)
553+
554+
return ifelse(if_var, logps_then, logps_else)

tests/logprob/test_mixture.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040
import pytest
4141
import scipy.stats.distributions as sp
4242

43+
from pytensor import function
4344
from pytensor.graph.basic import Variable, equal_computations
45+
from pytensor.ifelse import ifelse
4446
from pytensor.tensor.random.basic import CategoricalRV
4547
from pytensor.tensor.shape import shape_tuple
4648
from pytensor.tensor.subtensor import as_index_constant
@@ -942,3 +944,109 @@ def test_switch_mixture():
942944

943945
np.testing.assert_almost_equal(0.69049938, z1_logp.eval({z_vv: -10, i_vv: 0}))
944946
np.testing.assert_almost_equal(0.69049938, z2_logp.eval({z_vv: -10, i_vv: 0}))
947+
948+
949+
def test_ifelse_mixture_one_component():
950+
if_rv = pt.random.bernoulli(0.5, name="if")
951+
scale_rv = pt.random.halfnormal(name="scale")
952+
comp_then = pt.random.normal(0, scale_rv, size=(2,), name="comp_then")
953+
comp_else = pt.random.halfnormal(0, scale_rv, size=(4,), name="comp_else")
954+
mix_rv = ifelse(if_rv, comp_then, comp_else, name="mix")
955+
956+
if_vv = if_rv.clone()
957+
scale_vv = scale_rv.clone()
958+
mix_vv = mix_rv.clone()
959+
mix_logp = factorized_joint_logprob({if_rv: if_vv, scale_rv: scale_vv, mix_rv: mix_vv})[mix_vv]
960+
assert_no_rvs(mix_logp)
961+
962+
fn = function([if_vv, scale_vv, mix_vv], mix_logp)
963+
scale_vv_test = 0.75
964+
mix_vv_test = np.r_[1.0, 2.5]
965+
np.testing.assert_array_almost_equal(
966+
fn(1, scale_vv_test, mix_vv_test),
967+
sp.norm(0, scale_vv_test).logpdf(mix_vv_test),
968+
)
969+
mix_vv_test = np.r_[1.0, 2.5, 3.5, 4.0]
970+
np.testing.assert_array_almost_equal(
971+
fn(0, scale_vv_test, mix_vv_test), sp.halfnorm(0, scale_vv_test).logpdf(mix_vv_test)
972+
)
973+
974+
975+
def test_ifelse_mixture_multiple_components():
976+
rng = np.random.default_rng(968)
977+
978+
if_var = pt.scalar("if_var", dtype="bool")
979+
comp_then1 = pt.random.normal(size=(2,), name="comp_true1")
980+
comp_then2 = pt.random.normal(comp_then1, size=(2, 2), name="comp_then2")
981+
comp_else1 = pt.random.halfnormal(size=(4,), name="comp_else1")
982+
comp_else2 = pt.random.halfnormal(size=(4, 4), name="comp_else2")
983+
984+
mix_rv1, mix_rv2 = ifelse(
985+
if_var, [comp_then1, comp_then2], [comp_else1, comp_else2], name="mix"
986+
)
987+
mix_vv1 = mix_rv1.clone()
988+
mix_vv2 = mix_rv2.clone()
989+
mix_logp1, mix_logp2 = factorized_joint_logprob({mix_rv1: mix_vv1, mix_rv2: mix_vv2}).values()
990+
assert_no_rvs(mix_logp1)
991+
assert_no_rvs(mix_logp2)
992+
993+
fn = function([if_var, mix_vv1, mix_vv2], mix_logp1.sum() + mix_logp2.sum())
994+
mix_vv1_test = np.abs(rng.normal(size=(2,)))
995+
mix_vv2_test = np.abs(rng.normal(size=(2, 2)))
996+
np.testing.assert_almost_equal(
997+
fn(True, mix_vv1_test, mix_vv2_test),
998+
sp.norm(0, 1).logpdf(mix_vv1_test).sum()
999+
+ sp.norm(mix_vv1_test, 1).logpdf(mix_vv2_test).sum(),
1000+
)
1001+
mix_vv1_test = np.abs(rng.normal(size=(4,)))
1002+
mix_vv2_test = np.abs(rng.normal(size=(4, 4)))
1003+
np.testing.assert_almost_equal(
1004+
fn(False, mix_vv1_test, mix_vv2_test),
1005+
sp.halfnorm(0, 1).logpdf(mix_vv1_test).sum() + sp.halfnorm(0, 1).logpdf(mix_vv2_test).sum(),
1006+
)
1007+
1008+
1009+
def test_ifelse_mixture_shared_component():
1010+
rng = np.random.default_rng(1009)
1011+
1012+
if_var = pt.scalar("if_var", dtype="bool")
1013+
outer_rv = pt.random.normal(name="outer")
1014+
# comp_shared need not be an output of ifelse at all,
1015+
# but since we allow arbitrary graphs we test it works as expected.
1016+
comp_shared = pt.random.normal(size=(2,), name="comp_shared")
1017+
comp_then = outer_rv + pt.random.normal(comp_shared, 1, size=(4, 2), name="comp_then")
1018+
comp_else = outer_rv + pt.random.normal(comp_shared, 10, size=(8, 2), name="comp_else")
1019+
shared_rv, mix_rv = ifelse(
1020+
if_var, [comp_shared, comp_then], [comp_shared, comp_else], name="mix"
1021+
)
1022+
1023+
outer_vv = outer_rv.clone()
1024+
shared_vv = shared_rv.clone()
1025+
mix_vv = mix_rv.clone()
1026+
outer_logp, mix_logp1, mix_logp2 = factorized_joint_logprob(
1027+
{outer_rv: outer_vv, shared_rv: shared_vv, mix_rv: mix_vv}
1028+
).values()
1029+
assert_no_rvs(outer_logp)
1030+
assert_no_rvs(mix_logp1)
1031+
assert_no_rvs(mix_logp2)
1032+
1033+
fn = function([if_var, outer_vv, shared_vv, mix_vv], mix_logp1.sum() + mix_logp2.sum())
1034+
outer_vv_test = rng.normal()
1035+
shared_vv_test = rng.normal(size=(2,))
1036+
mix_vv_test = rng.normal(size=(4, 2))
1037+
np.testing.assert_almost_equal(
1038+
fn(True, outer_vv_test, shared_vv_test, mix_vv_test),
1039+
(
1040+
sp.norm(0, 1).logpdf(shared_vv_test).sum()
1041+
+ sp.norm(outer_vv_test + shared_vv_test, 1).logpdf(mix_vv_test).sum()
1042+
),
1043+
)
1044+
mix_vv_test = rng.normal(size=(8, 2))
1045+
np.testing.assert_almost_equal(
1046+
fn(False, outer_vv_test, shared_vv_test, mix_vv_test),
1047+
(
1048+
sp.norm(0, 1).logpdf(shared_vv_test).sum()
1049+
+ sp.norm(outer_vv_test + shared_vv_test, 10).logpdf(mix_vv_test).sum()
1050+
),
1051+
decimal=6,
1052+
)

0 commit comments

Comments
 (0)