
Commit 01ed4c0

.WIP
1 parent 6169cf3 commit 01ed4c0

4 files changed (+117, -110 lines)

pymc_experimental/model/marginal/distributions.py (+8, -13)
@@ -1,21 +1,18 @@
-from typing import Sequence
+from collections.abc import Sequence
 
 import numpy as np
 import pytensor.tensor as pt
-from pymc.distributions import (
-    Bernoulli,
-    Categorical,
-    DiscreteUniform,
-    SymbolicRandomVariable
-)
-from pymc.logprob.basic import conditional_logp, logp
+
+from pymc.distributions import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable
 from pymc.logprob.abstract import _logprob
+from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
-from pytensor.graph.replace import clone_replace, graph_replace
-from pytensor.scan import scan, map as scan_map
 from pytensor.compile.mode import Mode
 from pytensor.graph import vectorize_graph
-from pytensor.tensor import TensorVariable, TensorType
+from pytensor.graph.replace import clone_replace, graph_replace
+from pytensor.scan import map as scan_map
+from pytensor.scan import scan
+from pytensor.tensor import TensorType, TensorVariable
 
 from pymc_experimental.distributions import DiscreteMarkovChain
 

@@ -80,8 +77,6 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     inner_rv_values = dict(zip(inner_rvs, values))
     marginalized_vv = marginalized_rv.clone()
     rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
-    print("")
-    print("Inner conditional logp call >> ")
    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
 
     # Reduce logp dimensions corresponding to broadcasted variables
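Conceptually, the logp of a FiniteDiscreteMarginalRV evaluates the dependent conditional logps (via the conditional_logp call kept above) at every admissible value of the marginalized variable and reduces them with logsumexp. The following standalone sketch is illustrative only and not code from this commit; marginal_logp and normal_logpdf are hypothetical helpers:

import numpy as np
from scipy.special import logsumexp

def marginal_logp(y, p_idx, logp_y_given_idx):
    # log p(y) = logsumexp over k of [log p(idx = k) + log p(y | idx = k)]
    return logsumexp([np.log(p_idx[k]) + logp_y_given_idx(y, k) for k in range(len(p_idx))])

# Example: y ~ Normal(mu[idx], 1) with idx ~ Bernoulli(0.3), i.e. p_idx = (0.7, 0.3)
mu = np.array([-1.0, 1.0])

def normal_logpdf(y, k):
    return -0.5 * (y - mu[k]) ** 2 - 0.5 * np.log(2 * np.pi)

print(marginal_logp(0.5, p_idx=(0.7, 0.3), logp_y_given_idx=normal_logpdf))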

pymc_experimental/model/marginal/graph_analysis.py (+57, -31)
@@ -1,19 +1,21 @@
-from itertools import zip_longest, chain
-from typing import Sequence
+from collections.abc import Sequence
+from itertools import chain, zip_longest
 
 from pymc import SymbolicRandomVariable
 from pytensor.compile import SharedVariable
-from pytensor.graph import ancestors, Constant, graph_inputs, Variable
+from pytensor.graph import Constant, Variable, ancestors, graph_inputs
 from pytensor.graph.basic import io_toposort
-from pytensor.tensor import TensorVariable, TensorType
+from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.elemwise import DimShuffle, Elemwise, CAReduce
+from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.shape import Shape
-from pytensor.tensor.subtensor import Subtensor, get_idx_list, AdvancedSubtensor
+from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list
 from pytensor.tensor.type_other import NoneTypeT
 
+from pymc_experimental.model.marginal.distributions import FiniteDiscreteMarginalRV
+
 
 def static_shape_ancestors(vars):
     """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""

@@ -58,7 +60,6 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs):
     ]
 
 
-
 def collect_shared_vars(outputs, blockers):
     return [
         inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)

@@ -86,32 +87,24 @@ def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]:
     return adv_group_axis, adv_group_ndim
 
 
-def _broadcast_dims(inputs_dims: Sequence[tuple[tuple[int, ...], ...]]) -> tuple[tuple[int, ...], ...]:
+def _broadcast_dims(
+    inputs_dims: Sequence[tuple[tuple[int, ...], ...]],
+) -> tuple[tuple[int, ...], ...]:
     output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0)
     # Add missing dims
-    inputs_dims = [
-        ((),) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims
-    ]
+    inputs_dims = [((),) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims]
     # Combine aligned dims
-    output_dims = tuple(tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims))
+    output_dims = tuple(
+        tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims)
+    )
     return output_dims
 
 
-def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[tuple[int, ...], ...]]:
-    """Identify how the dims of rv_to_marginalize are consumed by the dims of the output_rvs.
-
-    Raises
-    ------
-    NotImplementedError
-        If variable related to marginalized batch_dims is used in an operation that is not yet supported
-
-    """
+VAR_DIMS = dict[Variable, tuple[tuple[int, ...], ...]]
 
-    var_dims: dict[Variable, tuple[tuple[int, ...], ...]] = {
-        input_var: tuple((i,) for i in range(input_var.type.ndim))
-    }
 
-    for node in io_toposort([input_var, *other_inputs], output_vars):
+def _subgraph_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR_DIMS:
+    for node in io_toposort(input_vars, output_vars):
         inputs_dims = [var_dims.get(inp, ()) for inp in node.inputs]
 
         if not any(inputs_dims):

@@ -126,6 +119,20 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
             )
             var_dims[node.outputs[0]] = output_dims
 
+        elif isinstance(node.op, FiniteDiscreteMarginalRV):
+            # FiniteDiscreteMarginalRV does not behave like a standard SymbolicRandomVariable, due to how we truncate the graph.
+            # We analyze the inner graph of the Marginalized RV to find the true dim connections
+            inner_var_dims = {
+                inner_inp: input_dims
+                for inner_inp, input_dims in zip(node.op.inner_inputs, inputs_dims)
+            }
+            inner_var_dims = _subgraph_dim_connection(
+                inner_var_dims, node.op.inner_inputs, node.op.inner_outputs
+            )
+            for out, inner_out in zip(node.outputs, node.op.inner_outputs):
+                if inner_out in inner_var_dims:
+                    var_dims[out] = inner_var_dims[inner_out]
+
         elif isinstance(node.op, Elemwise | Blockwise | RandomVariable | SymbolicRandomVariable):
             # NOTE: User-provided CustomDist may not respect core dimensions on the left.

@@ -135,13 +142,16 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
             op_batch_ndim = node.op.batch_ndim(node)
 
             # Collapse all core_dims
-            core_dims = tuple(sorted(chain.from_iterable([i for input_dim in inputs_dims for i in input_dim[op_batch_ndim:]])))
-            batch_dims = _broadcast_dims(
-                tuple(
-                    input_dims[:op_batch_ndim]
-                    for input_dims in inputs_dims
+            core_dims = tuple(
+                sorted(
+                    chain.from_iterable(
+                        [i for input_dim in inputs_dims for i in input_dim[op_batch_ndim:]]
+                    )
                 )
             )
+            batch_dims = _broadcast_dims(
+                tuple(input_dims[:op_batch_ndim] for input_dims in inputs_dims)
+            )
             # Add batch dims to each output_dims
             batch_dims = tuple(batch_dim + core_dims for batch_dim in batch_dims)
             for out in node.outputs:

@@ -221,7 +231,7 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
             elif value_dim:
                 # We are trying to partially slice or index a known dimension
                 raise NotImplementedError(
-                    f"Partial slicing or advanced integer indexing of known dimensions not supported"
+                    "Partial slicing or advanced integer indexing of known dimensions not supported"
                 )
             elif isinstance(idx, slice):
                 # Unknown dimensions kept by partial slice.

@@ -252,4 +262,20 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
         else:
             raise NotImplementedError(f"Marginalization through operation {node} not supported")
 
+    return var_dims
+
+
+def subgraph_dim_connection(
+    input_var, other_inputs, output_vars
+) -> list[tuple[tuple[int, ...], ...]]:
+    """Identify how the dims of rv_to_marginalize are consumed by the dims of the output_rvs.
+
+    Raises
+    ------
+    NotImplementedError
+        If variable related to marginalized batch_dims is used in an operation that is not yet supported
+
+    """
+    var_dims = {input_var: tuple((i,) for i in range(input_var.type.ndim))}
+    var_dims = _subgraph_dim_connection(var_dims, [input_var, *other_inputs], output_vars)
     return [var_dims[output_rv] for output_rv in output_vars]
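The refactor above splits the analysis into a recursive helper so that the inner graph of a FiniteDiscreteMarginalRV can be traversed with the same machinery. The bookkeeping convention is that each variable maps to one tuple entry per output dimension, listing which batch dims of the marginalized RV feed that dimension, and _broadcast_dims unions those entries across broadcasted inputs. A minimal self-contained illustration (the function body is the one added above; the example inputs are assumed):

from itertools import chain

def _broadcast_dims(inputs_dims):
    output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0)
    # Left-pad shorter inputs with empty dim sets, mirroring numpy-style broadcasting
    inputs_dims = [((),) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims]
    # Union the contributing marginalized dims per aligned output dimension
    return tuple(
        tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims)
    )

# A 2d variable carrying marginalized dims 0 and 1, broadcast against a dim-free 1d input:
print(_broadcast_dims([((0,), (1,)), ((),)]))  # -> ((0,), (1,))
# Two inputs carrying different marginalized dims on the same axis get mixed:
print(_broadcast_dims([((0,),), ((1,),)]))  # -> ((0, 1),)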

pymc_experimental/model/marginal/marginal_model.py (+34, -52)
@@ -13,21 +13,29 @@
 from pymc.distributions.transforms import Chain
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
-from pymc.pytensorf import compile_pymc, constant_fold, toposort_replace
+from pymc.pytensorf import compile_pymc, constant_fold
 from pymc.util import RandomState, _get_seeds_per_chain, treedict
 from pytensor.graph import FunctionGraph, clone_replace
-from pytensor.graph.basic import truncated_graph_inputs, Constant, ancestors
 from pytensor.graph.replace import vectorize_graph
-from pytensor.tensor import TensorVariable, extract_constant
+from pytensor.tensor import TensorVariable
 from pytensor.tensor.special import log_softmax
 
 __all__ = ["MarginalModel", "marginalize"]
 
 from pymc_experimental.distributions import DiscreteMarkovChain
-from pymc_experimental.model.marginal.distributions import FiniteDiscreteMarginalRV, DiscreteMarginalMarkovChainRV, \
-    get_domain_of_finite_discrete_rv, _add_reduce_batch_dependent_logps
-from pymc_experimental.model.marginal.graph_analysis import find_conditional_input_rvs, is_conditional_dependent, \
-    find_conditional_dependent_rvs, subgraph_dim_connection, collect_shared_vars
+from pymc_experimental.model.marginal.distributions import (
+    DiscreteMarginalMarkovChainRV,
+    FiniteDiscreteMarginalRV,
+    _add_reduce_batch_dependent_logps,
+    get_domain_of_finite_discrete_rv,
+)
+from pymc_experimental.model.marginal.graph_analysis import (
+    collect_shared_vars,
+    find_conditional_dependent_rvs,
+    find_conditional_input_rvs,
+    is_conditional_dependent,
+    subgraph_dim_connection,
+)
 
 ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str]
 

@@ -537,10 +545,6 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
 
 
 def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
-    # TODO: This should eventually be integrated in a more general routine that can
-    #  identify other types of supported marginalization, of which finite discrete
-    #  RVs is just one
-
     dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
     if not dependent_rvs:
         raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

@@ -552,7 +556,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
         if rv is not rv_to_marginalize
     ]
 
-    if all (rv_to_marginalize.type.broadcastable):
+    if rv_to_marginalize.type.ndim == 0:
         ndim_supp = max([dependent_rv.type.ndim for dependent_rv in dependent_rvs])
     else:
         # If the marginalized RV has multiple dimensions, check that graph between

@@ -561,23 +565,27 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
         dependent_rvs_dim_connections = subgraph_dim_connection(
             rv_to_marginalize, other_direct_rv_ancestors, dependent_rvs
         )
-        # dependent_rvs_dim_connections = subgraph_dim_connection(
-        #     rv_to_marginalize, other_inputs, dependent_rvs
-        # )
 
-        ndim_supp = max((dependent_rv.type.ndim - rv_to_marginalize.type.ndim) for dependent_rv in dependent_rvs)
+        ndim_supp = max(
+            (dependent_rv.type.ndim - rv_to_marginalize.type.ndim) for dependent_rv in dependent_rvs
+        )
 
-        if any(len(dim) > 1 for rv_dim_connections in dependent_rvs_dim_connections for dim in rv_dim_connections):
+        if any(
+            len(dim) > 1
+            for rv_dim_connections in dependent_rvs_dim_connections
+            for dim in rv_dim_connections
+        ):
             raise NotImplementedError("Multiple dimensions are mixed")
 
         # We further check that:
         # 1) Dimensions of dependent RVs are aligned with those of the marginalized RV
         # 2) Any extra batch dimensions of dependent RVs beyond those implied by the MarginalizedRV
         #    show up on the right, so that collapsing logic in logp can be more straightforward.
-        # This also ensures the MarginalizedRV still behaves as an RV itself
         marginal_batch_ndim = rv_to_marginalize.owner.op.batch_ndim(rv_to_marginalize.owner)
         marginal_batch_dims = tuple((i,) for i in range(marginal_batch_ndim))
-        for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_dim_connections):
+        for dependent_rv, dependent_rv_batch_dims in zip(
+            dependent_rvs, dependent_rvs_dim_connections
+        ):
             extra_batch_ndim = dependent_rv.type.ndim - marginal_batch_ndim
             valid_dependent_batch_dims = marginal_batch_dims + (((),) * extra_batch_ndim)
             if dependent_rv_batch_dims != valid_dependent_batch_dims:

@@ -587,47 +595,21 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
         )
 
     input_rvs = [*marginalized_rv_input_rvs, *other_direct_rv_ancestors]
-    rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]
+    output_rvs = [rv_to_marginalize, *dependent_rvs]
 
-    outputs = rvs_to_marginalize
     # We are strict about shared variables in SymbolicRandomVariables
-    inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)
-    # inputs = [
-    #     inp
-    #     for rv in rvs_to_marginalize  # should be toposort
-    #     for inp in rv.owner.inputs
-    #     if not(all(isinstance(a, Constant) for a in ancestors([inp], blockers=all_rvs)))
-    # ]
-    # inputs = [
-    #     inp for inp in truncated_graph_inputs(outputs, ancestors_to_include=inputs)
-    #     if not (all(isinstance(a, Constant) for a in ancestors([inp], blockers=all_rvs)))
-    # ]
-    # inputs = truncated_graph_inputs(outputs, ancestors_to_include=[
-    #     # inp
-    #     # for output in outputs
-    #     # for inp in output.owner.inputs
-    #     # ])
-    # inputs = [inp for inp in inputs if not isinstance(constant_fold([inp], raise_not_constant=False)[0], Constant | np.ndarray)]
+    inputs = input_rvs + collect_shared_vars(output_rvs, blockers=input_rvs)
+
     if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
         marginalize_constructor = DiscreteMarginalMarkovChainRV
     else:
         marginalize_constructor = FiniteDiscreteMarginalRV
 
     marginalization_op = marginalize_constructor(
         inputs=inputs,
-        outputs=outputs,
+        outputs=output_rvs,  # TODO: Add RNG updates to outputs
         ndim_supp=ndim_supp,
     )
-
-    marginalized_rvs = marginalization_op(*inputs)
-    print()
-    import pytensor
-    pytensor.dprint(marginalized_rvs, print_type=True)
-    fgraph.replace_all(reversed(tuple(zip(rvs_to_marginalize, marginalized_rvs))))
-    # assert 0
-    # fgraph.dprint()
-    # assert 0
-    # toposort_replace(fgraph, tuple(zip(rvs_to_marginalize, marginalized_rvs)))
-    # assert 0
-    return rvs_to_marginalize, marginalized_rvs
-
+    new_output_rvs = marginalization_op(*inputs)
+    fgraph.replace_all(tuple(zip(output_rvs, new_output_rvs)))
+    return output_rvs, new_output_rvs
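For context, this rewrite is what MarginalModel.marginalize ultimately triggers for a finite discrete RV; a hedged usage sketch (the model and data below are made up for illustration, not taken from this commit):

import numpy as np
import pymc as pm

from pymc_experimental import MarginalModel

with MarginalModel() as m:
    idx = pm.Bernoulli("idx", p=0.3)
    mu = pm.math.switch(idx, 1.0, -1.0)
    y = pm.Normal("y", mu=mu, sigma=1.0, observed=np.array([0.1, -0.3, 1.2]))

# Replaces `idx` and its dependents with a FiniteDiscreteMarginalRV subgraph,
# so sampling only sees the marginal (idx-free) model.
m.marginalize([idx])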

0 commit comments
