
Commit 6169cf3

.WIP
1 parent 378dbe4 commit 6169cf3

4 files changed, +93 -56 lines changed


pymc_experimental/model/marginal/distributions.py

+5 -2

@@ -39,7 +39,8 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
         return (0, 1)
     elif isinstance(op, Categorical):
         [p_param] = dist_params
-        return tuple(range(pt.get_vector_length(p_param)))
+        [p_param_length] = constant_fold([p_param.shape[-1]])
+        return tuple(range(p_param_length))
     elif isinstance(op, DiscreteUniform):
         lower, upper = constant_fold(dist_params)
         return tuple(np.arange(lower, upper + 1))
@@ -60,7 +61,7 @@ def _add_reduce_batch_dependent_logps(
     for dependent_logp in dependent_logps:
         dbcast = dependent_logp.type.broadcastable
         dim_diff = len(dbcast) - len(mbcast)
-        mbcast_aligned = (True,) * dim_diff + mbcast
+        mbcast_aligned = mbcast + (True,) * dim_diff
         vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
         reduced_logps.append(dependent_logp.sum(vbcast_axis))
     return pt.add(*reduced_logps)
@@ -79,6 +80,8 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     inner_rv_values = dict(zip(inner_rvs, values))
     marginalized_vv = marginalized_rv.clone()
     rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
+    print("")
+    print("Inner conditional logp call >> ")
     logps_dict = conditional_logp(rv_values=rv_values, **kwargs)

     # Reduce logp dimensions corresponding to broadcasted variables
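
Note (illustrative, not part of the commit): the Categorical change above swaps pt.get_vector_length, which assumes a 1-d probability vector, for constant-folding the size of p's last axis, which also works when p carries batch dimensions. A minimal sketch with made-up values:

    import numpy as np
    import pytensor.tensor as pt
    from pymc.pytensorf import constant_fold

    # Batched categorical probabilities: 2 batch rows, 3 categories each.
    p = pt.as_tensor(np.full((2, 3), 1 / 3))

    # The domain of the marginalized Categorical is range(size of the last axis of p),
    # regardless of how many batch dimensions p has.
    [n_categories] = constant_fold([p.shape[-1]])
    print(tuple(range(n_categories)))  # (0, 1, 2)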

pymc_experimental/model/marginal/marginal_model.py

+49 -32

@@ -13,11 +13,12 @@
 from pymc.distributions.transforms import Chain
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
-from pymc.pytensorf import compile_pymc, constant_fold
+from pymc.pytensorf import compile_pymc, constant_fold, toposort_replace
 from pymc.util import RandomState, _get_seeds_per_chain, treedict
 from pytensor.graph import FunctionGraph, clone_replace
+from pytensor.graph.basic import truncated_graph_inputs, Constant, ancestors
 from pytensor.graph.replace import vectorize_graph
-from pytensor.tensor import TensorVariable
+from pytensor.tensor import TensorVariable, extract_constant
 from pytensor.tensor.special import log_softmax

 __all__ = ["MarginalModel", "marginalize"]
@@ -544,52 +545,45 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     if not dependent_rvs:
         raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

-    ndim_supp = max({rv.owner.op.ndim_supp for rv in dependent_rvs})
-
     marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
     other_direct_rv_ancestors = [
         rv
         for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
         if rv is not rv_to_marginalize
     ]

-    # If the marginalized RV has multiple dimensions, check that graph between
-    # marginalized RV and dependent RVs does not mix information from batch dimensions
-    # (otherwise logp would require enuremating over all combinations of batch dimension values)
-    if any(not bcast for bcast in rv_to_marginalize.type.broadcastable):
-        # When there are batch dimensions, we call `batch_dims_subgraph` to make sure these are not mixed
-        dependent_rvs_dims = subgraph_dim_connection(
+    if all(rv_to_marginalize.type.broadcastable):
+        ndim_supp = max([dependent_rv.type.ndim for dependent_rv in dependent_rvs])
+    else:
+        # If the marginalized RV has multiple dimensions, check that graph between
+        # marginalized RV and dependent RVs does not mix information from batch dimensions
+        # (otherwise logp would require enumerating over all combinations of batch dimension values)
+        dependent_rvs_dim_connections = subgraph_dim_connection(
             rv_to_marginalize, other_direct_rv_ancestors, dependent_rvs
         )
+        # dependent_rvs_dim_connections = subgraph_dim_connection(
+        #     rv_to_marginalize, other_inputs, dependent_rvs
+        # )

-        # Cr
+        ndim_supp = max((dependent_rv.type.ndim - rv_to_marginalize.type.ndim) for dependent_rv in dependent_rvs)

-        if any(len(dim) > 1 for dim in dependent_rvs_dims):
+        if any(len(dim) > 1 for rv_dim_connections in dependent_rvs_dim_connections for dim in rv_dim_connections):
             raise NotImplementedError("Multiple dimensions are mixed")

-        # We further check that any extra batch dimensions of dependent RVs beyond those implied by the MarginalizedRV
-        # show up on the left, so that collapsing logic in logp can be more straightforward.
+        # We further check that:
+        # 1) Dimensions of dependent RVs are aligned with those of the marginalized RV
+        # 2) Any extra batch dimensions of dependent RVs beyond those implied by the MarginalizedRV
+        #    show up on the right, so that collapsing logic in logp can be more straightforward.
         # This also ensures the MarginalizedRV still behaves as an RV itself
         marginal_batch_ndim = rv_to_marginalize.owner.op.batch_ndim(rv_to_marginalize.owner)
         marginal_batch_dims = tuple((i,) for i in range(marginal_batch_ndim))
-        for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_dims):
-            extra_batch_ndim = (
-                dependent_rv.type.ndim - marginal_batch_ndim - dependent_rv.owner.op.ndim_supp
-            )
-            valid_dependent_batch_dims = (((),) * extra_batch_ndim) + marginal_batch_dims
+        for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_dim_connections):
+            extra_batch_ndim = dependent_rv.type.ndim - marginal_batch_ndim
+            valid_dependent_batch_dims = marginal_batch_dims + (((),) * extra_batch_ndim)
             if dependent_rv_batch_dims != valid_dependent_batch_dims:
                 raise NotImplementedError(
-                    "Any extra batch dimensions introduced by dependent RVs must be "
-                    "on the left of dimensions introduced by the marginalized RV"
-                )
-
-        for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_dims):
-            shared_batch_dims = [
-                batch_dim for batch_dim in dependent_rv_batch_dims if batch_dim is not None
-            ]
-            if shared_batch_dims != sorted(shared_batch_dims):
-                raise NotImplementedError(
-                    "Shared batch dimensions between marginalized RV and dependent RVs must be aligned positionally"
+                    "Any extra dimensions introduced by dependent RVs must appear to the right of dimensions "
+                    "introduced by the marginalized RV."
                 )

     input_rvs = [*marginalized_rv_input_rvs, *other_direct_rv_ancestors]
@@ -598,7 +592,22 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     outputs = rvs_to_marginalize
     # We are strict about shared variables in SymbolicRandomVariables
     inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)
-
+    # inputs = [
+    #     inp
+    #     for rv in rvs_to_marginalize  # should be toposort
+    #     for inp in rv.owner.inputs
+    #     if not(all(isinstance(a, Constant) for a in ancestors([inp], blockers=all_rvs)))
+    # ]
+    # inputs = [
+    #     inp for inp in truncated_graph_inputs(outputs, ancestors_to_include=inputs)
+    #     if not (all(isinstance(a, Constant) for a in ancestors([inp], blockers=all_rvs)))
+    # ]
+    # inputs = truncated_graph_inputs(outputs, ancestors_to_include=[
+    #     # inp
+    #     # for output in outputs
+    #     # for inp in output.owner.inputs
+    #     # ])
+    # inputs = [inp for inp in inputs if not isinstance(constant_fold([inp], raise_not_constant=False)[0], Constant | np.ndarray)]
     if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
         marginalize_constructor = DiscreteMarginalMarkovChainRV
     else:
@@ -611,6 +620,14 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     )

     marginalized_rvs = marginalization_op(*inputs)
-    fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
+    print()
+    import pytensor
+    pytensor.dprint(marginalized_rvs, print_type=True)
+    fgraph.replace_all(reversed(tuple(zip(rvs_to_marginalize, marginalized_rvs))))
+    # assert 0
+    # fgraph.dprint()
+    # assert 0
+    # toposort_replace(fgraph, tuple(zip(rvs_to_marginalize, marginalized_rvs)))
+    # assert 0
     return rvs_to_marginalize, marginalized_rvs
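
Note (illustrative, not part of the commit): a small sketch of the alignment rule enforced above, using hypothetical shapes. Any batch dimension a dependent RV adds beyond the marginalized RV's batch dimensions must now sit to the right, so the expected dim-connection pattern is the marginalized dims followed by empty tuples:

    # Marginalized RV with batch shape (3,); dependent RV adds one trailing dim, e.g. shape (3, 5).
    marginal_batch_ndim = 1
    marginal_batch_dims = tuple((i,) for i in range(marginal_batch_ndim))  # ((0,),)

    dependent_rv_ndim = 2
    extra_batch_ndim = dependent_rv_ndim - marginal_batch_ndim
    valid_dependent_batch_dims = marginal_batch_dims + (((),) * extra_batch_ndim)
    print(valid_dependent_batch_dims)  # ((0,), ()): axis 0 maps to the marginalized dim, axis 1 is new

A dependent RV with the extra dimension on the left (shape (5, 3)) would not match this pattern and would hit the NotImplementedError above.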

tests/model/marginal/test_graph_analysis.py

+3 -2

@@ -39,6 +39,9 @@ def test_subtensor(self):
         [dims] = subgraph_dim_connection(inp, [], [valid_out])
         assert dims == ((), (2,))

+    def test_advanced_subtensor(self):
+        raise NotImplementedError()
+
     def test_elemwise(self):
         inp = pt.zeros(shape=(5, 5))

@@ -93,5 +96,3 @@ def test_symbolic_random_variable(self):
         [dims] = subgraph_dim_connection(inp, [], [out])
         assert dims == ((0, 2), (1, 2))

-    def test_advanced_indexing(self):
-        raise NotImplementedError()

tests/model/marginal/test_marginal_model.py

+36 -20

@@ -133,36 +133,52 @@ def test_rv_dependent_multiple_marginalized_rvs():
     np.testing.assert_allclose(np.exp(logp({"z": 2})), 0.1 * 0.3)


-def test_nested_marginalized_rvs():
+@pytest.mark.parametrize("batched", (False, True))
+def test_nested_marginalized_rvs(batched):
     """Test that marginalization works when there are nested marginalized RVs"""

-    with MarginalModel() as m:
-        sigma = pm.HalfNormal("sigma")
+    def build_model(build_batched: bool) -> MarginalModel:
+        idx_shape = (3,) if build_batched else ()
+        sub_idx_shape = (3, 5) if build_batched else (5,)

-        idx = pm.Bernoulli("idx", p=0.75)
-        dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma)
+        with MarginalModel() as m:
+            sigma = pm.HalfNormal("sigma")

-        sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95), shape=(5,))
-        sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma, shape=(5,))
+            idx = pm.Bernoulli("idx", p=0.75, shape=idx_shape)
+            dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma)

-    ref_logp_fn = m.compile_logp(vars=[idx, dep, sub_idx, sub_dep])
+            sub_idx_p = pt.switch(pt.eq(idx, 0), 0.15, 0.95)
+            if build_batched:
+                sub_idx_p = sub_idx_p[:, None]
+                dep = dep[:, None]
+            sub_idx = pm.Bernoulli("sub_idx", p=sub_idx_p, shape=sub_idx_shape)
+            sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma)

-    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-        m.marginalize([idx, sub_idx])
+        return m

-    assert set(m.marginalized_rvs) == {idx, sub_idx}
+    m = build_model(build_batched=batched)
+    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+        m.marginalize(["idx", "sub_idx"])
+    assert sorted(m.name for m in m.marginalized_rvs) == ["idx", "sub_idx"]
+    return

     # Test logp
-    test_point = m.initial_point()
-    test_point["dep"] = 1000
-    test_point["sub_dep"] = np.full((5,), 1000 + 100)
+    ref_m = build_model(build_batched=False)
+    ref_logp_fn = ref_m.compile_logp(vars=[ref_m["idx"], ref_m["dep"], ref_m["sub_idx"], ref_m["sub_dep"]])

+    test_point = ref_m.initial_point()
+    test_point["dep"] = np.full_like(test_point["dep"], 1000)
+    test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100)
     ref_logp = [
         ref_logp_fn({**test_point, **{"idx": idx, "sub_idx": np.array(sub_idxs)}})
         for idx in (0, 1)
         for sub_idxs in itertools.product((0, 1), repeat=5)
     ]
-    logp = m.compile_logp(vars=[dep, sub_dep])(test_point)
+
+    test_point = m.initial_point()
+    test_point["dep"] = np.full_like(test_point["dep"], 1000)
+    test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100)
+    logp = m.compile_logp(vars=[m["dep"], m["sub_dep"]])(test_point)

     np.testing.assert_almost_equal(
         logp,
@@ -615,8 +631,8 @@ def test_change_point_model_sampling(self, disaster_model):

     def test_k_censored_clusters_model(self):
         def build_model(batch: bool) -> MarginalModel:
-            data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]]).T
-            nobs = data.shape[-1]
+            data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
+            nobs = data.shape[0]
             n_clusters = 5
             coords = {
                 "cluster": range(n_clusters),
@@ -641,17 +657,17 @@ def build_model(batch: bool) -> MarginalModel:
                     initval=np.linspace(-1, 1, n_clusters),
                 )
                 mu_y = pm.Normal("mu_y", dims=["cluster"])
-                mu = pm.math.concatenate([mu_x[None], mu_y[None]], axis=0)  # (ndim, cluster)
+                mu = pm.math.stack([mu_x, mu_y], axis=-1)  # (cluster, ndim)

                 sigma = pm.HalfNormal("sigma")

                 y = pm.Censored(
                     "y",
-                    dist=pm.Normal.dist(mu[:, idx], sigma),
+                    dist=pm.Normal.dist(mu[idx, :], sigma),
                     lower=-3,
                     upper=3,
                     observed=data,
-                    dims=["ndim", "obs"],
+                    dims=["obs", "ndim"],
                 )

             return m
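
Note (illustrative, not part of the commit): a standalone numpy sketch, with assumed values, of why the cluster means are now stacked along the last axis in the censored-clusters test. Indexing the (cluster, ndim) array with one cluster index per observation yields an (obs, ndim) array, which lines up with the new dims=["obs", "ndim"]:

    import numpy as np

    mu_x = np.linspace(-1, 1, 5)          # hypothetical per-cluster means
    mu_y = np.zeros(5)
    mu = np.stack([mu_x, mu_y], axis=-1)  # shape (5, 2): (cluster, ndim)

    idx = np.array([0, 2, 4])             # one cluster assignment per observation
    print(mu[idx, :].shape)               # (3, 2): (obs, ndim)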
