13
13
from pymc .distributions .transforms import Chain
14
14
from pymc .logprob .transforms import IntervalTransform
15
15
from pymc .model import Model
16
- from pymc .pytensorf import compile_pymc , constant_fold , toposort_replace
16
+ from pymc .pytensorf import compile_pymc , constant_fold
17
17
from pymc .util import RandomState , _get_seeds_per_chain , treedict
18
18
from pytensor .graph import FunctionGraph , clone_replace
19
- from pytensor .graph .basic import truncated_graph_inputs , Constant , ancestors
20
19
from pytensor .graph .replace import vectorize_graph
21
- from pytensor .tensor import TensorVariable , extract_constant
20
+ from pytensor .tensor import TensorVariable
22
21
from pytensor .tensor .special import log_softmax
23
22
24
23
__all__ = ["MarginalModel" , "marginalize" ]
25
24
26
25
from pymc_experimental .distributions import DiscreteMarkovChain
27
- from pymc_experimental .model .marginal .distributions import FiniteDiscreteMarginalRV , DiscreteMarginalMarkovChainRV , \
28
- get_domain_of_finite_discrete_rv , _add_reduce_batch_dependent_logps
29
- from pymc_experimental .model .marginal .graph_analysis import find_conditional_input_rvs , is_conditional_dependent , \
30
- find_conditional_dependent_rvs , subgraph_dim_connection , collect_shared_vars
26
+ from pymc_experimental .model .marginal .distributions import (
27
+ DiscreteMarginalMarkovChainRV ,
28
+ FiniteDiscreteMarginalRV ,
29
+ _add_reduce_batch_dependent_logps ,
30
+ get_domain_of_finite_discrete_rv ,
31
+ )
32
+ from pymc_experimental .model .marginal .graph_analysis import (
33
+ collect_shared_vars ,
34
+ find_conditional_dependent_rvs ,
35
+ find_conditional_input_rvs ,
36
+ is_conditional_dependent ,
37
+ subgraph_dim_connection ,
38
+ )
31
39
32
40
ModelRVs = TensorVariable | Sequence [TensorVariable ] | str | Sequence [str ]
33
41
@@ -537,10 +545,6 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
537
545
538
546
539
547
def replace_finite_discrete_marginal_subgraph (fgraph , rv_to_marginalize , all_rvs ):
540
- # TODO: This should eventually be integrated in a more general routine that can
541
- # identify other types of supported marginalization, of which finite discrete
542
- # RVs is just one
543
-
544
548
dependent_rvs = find_conditional_dependent_rvs (rv_to_marginalize , all_rvs )
545
549
if not dependent_rvs :
546
550
raise ValueError (f"No RVs depend on marginalized RV { rv_to_marginalize } " )
@@ -552,7 +556,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
552
556
if rv is not rv_to_marginalize
553
557
]
554
558
555
- if all ( rv_to_marginalize .type .broadcastable ) :
559
+ if rv_to_marginalize .type .ndim == 0 :
556
560
ndim_supp = max ([dependent_rv .type .ndim for dependent_rv in dependent_rvs ])
557
561
else :
558
562
# If the marginalized RV has multiple dimensions, check that graph between
@@ -561,23 +565,27 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
561
565
dependent_rvs_dim_connections = subgraph_dim_connection (
562
566
rv_to_marginalize , other_direct_rv_ancestors , dependent_rvs
563
567
)
564
- # dependent_rvs_dim_connections = subgraph_dim_connection(
565
- # rv_to_marginalize, other_inputs, dependent_rvs
566
- # )
567
568
568
- ndim_supp = max ((dependent_rv .type .ndim - rv_to_marginalize .type .ndim ) for dependent_rv in dependent_rvs )
569
+ ndim_supp = max (
570
+ (dependent_rv .type .ndim - rv_to_marginalize .type .ndim ) for dependent_rv in dependent_rvs
571
+ )
569
572
570
- if any (len (dim ) > 1 for rv_dim_connections in dependent_rvs_dim_connections for dim in rv_dim_connections ):
573
+ if any (
574
+ len (dim ) > 1
575
+ for rv_dim_connections in dependent_rvs_dim_connections
576
+ for dim in rv_dim_connections
577
+ ):
571
578
raise NotImplementedError ("Multiple dimensions are mixed" )
572
579
573
580
# We further check that:
574
581
# 1) Dimensions of dependent RVs are aligned with those of the marginalized RV
575
582
# 2) Any extra batch dimensions of dependent RVs beyond those implied by the MarginalizedRV
576
583
# show up on the right, so that collapsing logic in logp can be more straightforward.
577
- # This also ensures the MarginalizedRV still behaves as an RV itself
578
584
marginal_batch_ndim = rv_to_marginalize .owner .op .batch_ndim (rv_to_marginalize .owner )
579
585
marginal_batch_dims = tuple ((i ,) for i in range (marginal_batch_ndim ))
580
- for dependent_rv , dependent_rv_batch_dims in zip (dependent_rvs , dependent_rvs_dim_connections ):
586
+ for dependent_rv , dependent_rv_batch_dims in zip (
587
+ dependent_rvs , dependent_rvs_dim_connections
588
+ ):
581
589
extra_batch_ndim = dependent_rv .type .ndim - marginal_batch_ndim
582
590
valid_dependent_batch_dims = marginal_batch_dims + (((),) * extra_batch_ndim )
583
591
if dependent_rv_batch_dims != valid_dependent_batch_dims :
@@ -587,47 +595,21 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
587
595
)
588
596
589
597
input_rvs = [* marginalized_rv_input_rvs , * other_direct_rv_ancestors ]
590
- rvs_to_marginalize = [rv_to_marginalize , * dependent_rvs ]
598
+ output_rvs = [rv_to_marginalize , * dependent_rvs ]
591
599
592
- outputs = rvs_to_marginalize
593
600
# We are strict about shared variables in SymbolicRandomVariables
594
- inputs = input_rvs + collect_shared_vars (rvs_to_marginalize , blockers = input_rvs )
595
- # inputs = [
596
- # inp
597
- # for rv in rvs_to_marginalize # should be toposort
598
- # for inp in rv.owner.inputs
599
- # if not(all(isinstance(a, Constant) for a in ancestors([inp], blockers=all_rvs)))
600
- # ]
601
- # inputs = [
602
- # inp for inp in truncated_graph_inputs(outputs, ancestors_to_include=inputs)
603
- # if not (all(isinstance(a, Constant) for a in ancestors([inp], blockers=all_rvs)))
604
- # ]
605
- # inputs = truncated_graph_inputs(outputs, ancestors_to_include=[
606
- # # inp
607
- # # for output in outputs
608
- # # for inp in output.owner.inputs
609
- # # ])
610
- # inputs = [inp for inp in inputs if not isinstance(constant_fold([inp], raise_not_constant=False)[0], Constant | np.ndarray)]
601
+ inputs = input_rvs + collect_shared_vars (output_rvs , blockers = input_rvs )
602
+
611
603
if isinstance (rv_to_marginalize .owner .op , DiscreteMarkovChain ):
612
604
marginalize_constructor = DiscreteMarginalMarkovChainRV
613
605
else :
614
606
marginalize_constructor = FiniteDiscreteMarginalRV
615
607
616
608
marginalization_op = marginalize_constructor (
617
609
inputs = inputs ,
618
- outputs = outputs ,
610
+ outputs = output_rvs , # TODO: Add RNG updates to outputs
619
611
ndim_supp = ndim_supp ,
620
612
)
621
-
622
- marginalized_rvs = marginalization_op (* inputs )
623
- print ()
624
- import pytensor
625
- pytensor .dprint (marginalized_rvs , print_type = True )
626
- fgraph .replace_all (reversed (tuple (zip (rvs_to_marginalize , marginalized_rvs ))))
627
- # assert 0
628
- # fgraph.dprint()
629
- # assert 0
630
- # toposort_replace(fgraph, tuple(zip(rvs_to_marginalize, marginalized_rvs)))
631
- # assert 0
632
- return rvs_to_marginalize , marginalized_rvs
633
-
613
+ new_output_rvs = marginalization_op (* inputs )
614
+ fgraph .replace_all (tuple (zip (output_rvs , new_output_rvs )))
615
+ return output_rvs , new_output_rvs
0 commit comments