from pymc.distributions.transforms import Chain
from pymc.logprob.transforms import IntervalTransform
from pymc.model import Model
- from pymc.pytensorf import compile_pymc, constant_fold
+ from pymc.pytensorf import compile_pymc, constant_fold, toposort_replace
from pymc.util import RandomState, _get_seeds_per_chain, treedict
from pytensor.graph import FunctionGraph, clone_replace
+ from pytensor.graph.basic import truncated_graph_inputs, Constant, ancestors
from pytensor.graph.replace import vectorize_graph
- from pytensor.tensor import TensorVariable
+ from pytensor.tensor import TensorVariable, extract_constant
from pytensor.tensor.special import log_softmax

__all__ = ["MarginalModel", "marginalize"]
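For orientation, a minimal usage sketch of the two names exported above. The import path and the toy model are illustrative assumptions (pymc-experimental style packaging), not part of this diff:

import pymc as pm
from pymc_experimental import MarginalModel

with MarginalModel() as m:
    # A finite discrete indicator and a variable that depends on it
    idx = pm.Bernoulli("idx", p=0.7)
    signal = pm.Normal("signal", mu=pm.math.switch(idx, 1.0, -1.0), sigma=1.0)

# Sum the discrete `idx` out of the model's logp; sampling afterwards only
# deals with the continuous `signal`. The module-level `marginalize` helper
# exposes the same operation for an existing model object.
m.marginalize([idx])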
@@ -544,52 +545,45 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
    if not dependent_rvs:
        raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

-     ndim_supp = max({rv.owner.op.ndim_supp for rv in dependent_rvs})
-
    marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
    other_direct_rv_ancestors = [
        rv
        for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
        if rv is not rv_to_marginalize
    ]

-     # If the marginalized RV has multiple dimensions, check that graph between
-     # marginalized RV and dependent RVs does not mix information from batch dimensions
-     # (otherwise logp would require enuremating over all combinations of batch dimension values)
-     if any(not bcast for bcast in rv_to_marginalize.type.broadcastable):
-         # When there are batch dimensions, we call `batch_dims_subgraph` to make sure these are not mixed
-         dependent_rvs_dims = subgraph_dim_connection(
+     if all(rv_to_marginalize.type.broadcastable):
+         ndim_supp = max([dependent_rv.type.ndim for dependent_rv in dependent_rvs])
+     else:
+         # If the marginalized RV has multiple dimensions, check that the graph between
+         # the marginalized RV and the dependent RVs does not mix information from batch dimensions
+         # (otherwise the logp would require enumerating over all combinations of batch dimension values)
+         dependent_rvs_dim_connections = subgraph_dim_connection(
            rv_to_marginalize, other_direct_rv_ancestors, dependent_rvs
        )
+         # dependent_rvs_dim_connections = subgraph_dim_connection(
+         #     rv_to_marginalize, other_inputs, dependent_rvs
+         # )

-         # Cr
+         ndim_supp = max((dependent_rv.type.ndim - rv_to_marginalize.type.ndim) for dependent_rv in dependent_rvs)

-         if any(len(dim) > 1 for dim in dependent_rvs_dims):
+         if any(len(dim) > 1 for rv_dim_connections in dependent_rvs_dim_connections for dim in rv_dim_connections):
            raise NotImplementedError("Multiple dimensions are mixed")

-         # We further check that any extra batch dimensions of dependent RVs beyond those implied by the MarginalizedRV
-         # show up on the left, so that collapsing logic in logp can be more straightforward.
+         # We further check that:
+         # 1) Dimensions of dependent RVs are aligned with those of the marginalized RV
+         # 2) Any extra batch dimensions of dependent RVs beyond those implied by the MarginalizedRV
+         #    show up on the right, so that the collapsing logic in logp can be more straightforward.
        # This also ensures the MarginalizedRV still behaves as an RV itself
        marginal_batch_ndim = rv_to_marginalize.owner.op.batch_ndim(rv_to_marginalize.owner)
        marginal_batch_dims = tuple((i,) for i in range(marginal_batch_ndim))
-         for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_dims):
-             extra_batch_ndim = (
-                 dependent_rv.type.ndim - marginal_batch_ndim - dependent_rv.owner.op.ndim_supp
-             )
-             valid_dependent_batch_dims = (((),) * extra_batch_ndim) + marginal_batch_dims
+         for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_dim_connections):
+             extra_batch_ndim = dependent_rv.type.ndim - marginal_batch_ndim
+             valid_dependent_batch_dims = marginal_batch_dims + (((),) * extra_batch_ndim)
            if dependent_rv_batch_dims != valid_dependent_batch_dims:
                raise NotImplementedError(
-                     "Any extra batch dimensions introduced by dependent RVs must be "
-                     "on the left of dimensions introduced by the marginalized RV"
-                 )
-
-         for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_dims):
-             shared_batch_dims = [
-                 batch_dim for batch_dim in dependent_rv_batch_dims if batch_dim is not None
-             ]
-             if shared_batch_dims != sorted(shared_batch_dims):
-                 raise NotImplementedError(
-                     "Shared batch dimensions between marginalized RV and dependent RVs must be aligned positionally"
+                     "Any extra dimensions introduced by dependent RVs must appear to the right of dimensions "
+                     "introduced by the marginalized RV."
                )

    input_rvs = [*marginalized_rv_input_rvs, *other_direct_rv_ancestors]
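To make the new alignment rule concrete, here is a plain-Python sketch of the bookkeeping it performs, using made-up dimensionalities rather than actual PyMC variables:

# Marginalized RV with 2 batch dimensions (hypothetical numbers)
marginal_batch_ndim = 2
marginal_batch_dims = tuple((i,) for i in range(marginal_batch_ndim))  # ((0,), (1,))

# Dependent RV with one extra dimension of its own
dependent_rv_ndim = 3
extra_batch_ndim = dependent_rv_ndim - marginal_batch_ndim  # 1

# Accepted pattern: the marginalized dims come first (aligned), extra dims sit on the right
valid_dependent_batch_dims = marginal_batch_dims + (((),) * extra_batch_ndim)
assert valid_dependent_batch_dims == ((0,), (1,), ())

# A connection such as ((1,), (0,), ()) (transposed dims) or ((), (0,), (1,))
# (extra dim on the left) differs from the accepted pattern, so the check above
# raises NotImplementedError for it.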
@@ -598,7 +592,22 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
    outputs = rvs_to_marginalize
    # We are strict about shared variables in SymbolicRandomVariables
    inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)
-
+     # inputs = [
+     #     inp
+     #     for rv in rvs_to_marginalize  # should be toposort
+     #     for inp in rv.owner.inputs
+     #     if not(all(isinstance(a, Constant) for a in ancestors([inp], blockers=all_rvs)))
+     # ]
+     # inputs = [
+     #     inp for inp in truncated_graph_inputs(outputs, ancestors_to_include=inputs)
+     #     if not (all(isinstance(a, Constant) for a in ancestors([inp], blockers=all_rvs)))
+     # ]
+     # inputs = truncated_graph_inputs(outputs, ancestors_to_include=[
+     #     # inp
+     #     # for output in outputs
+     #     # for inp in output.owner.inputs
+     #     # ])
+     # inputs = [inp for inp in inputs if not isinstance(constant_fold([inp], raise_not_constant=False)[0], Constant | np.ndarray)]
    if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
        marginalize_constructor = DiscreteMarginalMarkovChainRV
    else:
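The commented-out exploration above leans on truncated_graph_inputs, newly imported at the top of this diff. A minimal sketch of that helper on a toy graph, under the assumption that it behaves as documented in PyTensor (returning a cut of the output graph that keeps the requested ancestors as explicit inputs):

import pytensor.tensor as pt
from pytensor.graph.basic import truncated_graph_inputs

x = pt.scalar("x")
y = pt.scalar("y")
z = pt.exp(x) + y  # intermediate variable we want to keep explicit
out = pt.log(z)

# Cut the graph of `out` so that `z` is preserved as an input of the truncated
# graph instead of being traced back to `x` and `y`.
cut = truncated_graph_inputs([out], ancestors_to_include=[z])
print(cut)  # expected to contain `z`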
@@ -611,6 +620,14 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
    )

    marginalized_rvs = marginalization_op(*inputs)
-     fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
+     print()
+     import pytensor
+     pytensor.dprint(marginalized_rvs, print_type=True)
+     fgraph.replace_all(reversed(tuple(zip(rvs_to_marginalize, marginalized_rvs))))
+     # assert 0
+     # fgraph.dprint()
+     # assert 0
+     # toposort_replace(fgraph, tuple(zip(rvs_to_marginalize, marginalized_rvs)))
+     # assert 0
    return rvs_to_marginalize, marginalized_rvs
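The replacement step itself is a stock FunctionGraph operation. Below is a minimal sketch on a toy graph, with hypothetical variables unrelated to the marginalization machinery; toposort_replace from pymc.pytensorf, imported at the top of this diff and shown commented out above, is the variant that applies the replacement pairs in topological order:

import pytensor.tensor as pt
from pytensor.graph import FunctionGraph

x = pt.scalar("x")
old_var = pt.exp(x)
out = old_var + 1.0

# clone=False so the replacement acts on the exact variables we hold references to
fgraph = FunctionGraph(inputs=[x], outputs=[out], clone=False)

new_var = pt.log1p(x)
# Swap every (old, new) pair inside the graph in one pass
fgraph.replace_all([(old_var, new_var)])
print(fgraph.outputs)  # the output now computes log1p(x) + 1.0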