1
+ from typing import Sequence
2
+
3
+ from pymc import STEP_METHODS
4
+ from pytensor .tensor .random .type import RandomGeneratorType
5
+
6
+ from pytensor .compile .builders import OpFromGraph
7
+
8
+ from pymc_experimental .sampling .mcmc import posterior_optimization_db
9
+ from pymc_experimental .sampling .optimizations .conjugate_sampler import ConjugateRV , ConjugateRVSampler
10
+
11
+ STEP_METHODS .append (ConjugateRVSampler )
12
+
13
+ from pytensor .graph .fg import Output
14
+ from pytensor .tensor .elemwise import DimShuffle
15
+ from pymc .model .fgraph import model_free_rv , ModelValuedVar
16
+
17
+
18
+ from pytensor .graph .basic import Variable
19
+ from pytensor .graph .fg import FunctionGraph
20
+ from pytensor .graph .rewriting .basic import node_rewriter
21
+ from pymc .model .fgraph import ModelFreeRV
22
+ from pymc .distributions import Beta , Binomial
23
+ from pymc .pytensorf import collect_default_updates
24
+
25
+
26
+ def get_model_var_of_rv (fgraph : FunctionGraph , rv : Variable ) -> Variable :
27
+ """Return the Model dummy var that wraps the RV"""
28
+ for client , _ in fgraph .clients [rv ]:
29
+ if isinstance (client .op , ModelValuedVar ):
30
+ return client .outputs [0 ]
31
+
32
+
33
+ def get_dist_params (rv : Variable ) -> tuple [Variable ]:
34
+ return rv .owner .op .dist_params (rv .owner )
35
+
36
+
37
+ def rv_used_by (fgraph : FunctionGraph , rv : Variable , used_by_type : type , used_as_arg_idx : int | Sequence [int ], strict : bool = True ) -> list [Variable ]:
38
+ """Return the RVs that use `rv` as an argument in an operation of type `used_by_type`.
39
+
40
+ RV may be used directly or broadcasted before being used.
41
+
42
+ Parameters
43
+ ----------
44
+ fgraph : FunctionGraph
45
+ The function graph containing the RVs
46
+ rv : Variable
47
+ The RV to check for uses.
48
+ used_by_type : type
49
+ The type of operation that may use the RV.
50
+ used_as_arg_idx : int | Sequence[int]
51
+ The index of the RV in the operation's inputs.
52
+ strict : bool, default=True
53
+ If True, return no results when the RV is used in an unrecognized way.
54
+
55
+ """
56
+ if isinstance (used_as_arg_idx , int ):
57
+ used_as_arg_idx = (used_as_arg_idx ,)
58
+
59
+ clients = fgraph .clients
60
+ used_by : list [Variable ] = []
61
+ for client , inp_idx in clients [rv ]:
62
+ if isinstance (client .op , Output ):
63
+ continue
64
+
65
+ if isinstance (client .op , used_by_type ) and inp_idx in used_as_arg_idx :
66
+ # RV is directly used by the RV type
67
+ used_by .append (client .default_output ())
68
+
69
+ elif isinstance (client .op , DimShuffle ) and client .op .is_left_expand_dims :
70
+ for sub_client , sub_inp_idx in clients [client .outputs [0 ]]:
71
+ if isinstance (sub_client .op , used_by_type ) and sub_inp_idx in used_as_arg_idx :
72
+ # RV is broadcasted and then used by the RV type
73
+ used_by .append (sub_client .default_output ())
74
+ elif strict :
75
+ # Some other unrecognized use, bail out
76
+ return []
77
+ elif strict :
78
+ # Some other unrecognized use, bail out
79
+ return []
80
+
81
+ return used_by
82
+
83
+
84
+ def wrap_rv_and_conjugate_rv (fgraph : FunctionGraph , rv : Variable , conjugate_rv : Variable , inputs : Sequence [Variable ]) -> Variable :
85
+ """Wrap the RV and its conjugate posterior RV in a ConjugateRV node.
86
+
87
+ Also takes care of handling the random number generators used in the conjugate posterior.
88
+ """
89
+ rngs , next_rngs = zip (* collect_default_updates (conjugate_rv , inputs = [rv , * inputs ]).items ())
90
+ for rng in rngs :
91
+ if rng not in fgraph .inputs :
92
+ fgraph .add_input (rng )
93
+ conjugate_op = ConjugateRV (inputs = [rv , * inputs , * rngs ], outputs = [rv , conjugate_rv , * next_rngs ])
94
+ return conjugate_op (rv , * inputs , * rngs )[0 ]
95
+
96
+
97
+ def create_untransformed_free_rv (fgraph : FunctionGraph , rv : Variable , name : str , dims : Sequence [str | Variable ]) -> Variable :
98
+ """Create a model FreeRV without transform."""
99
+ transform = None
100
+ value = rv .type (name = name )
101
+ fgraph .add_input (value )
102
+ free_rv = model_free_rv (rv , value , transform , * dims )
103
+ free_rv .name = name
104
+ return free_rv
105
+
106
+
107
+ @node_rewriter (tracks = [ModelFreeRV ])
108
+ def beta_binomial_conjugacy (fgraph : FunctionGraph , node ):
109
+ """This applies the equivalence (up to a normalizing constant) described in:
110
+
111
+ https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics
112
+ """
113
+ [beta_free_rv ] = node .outputs
114
+ beta_rv , beta_value , * beta_dims = node .inputs
115
+
116
+ if not isinstance (beta_rv .owner .op , Beta ):
117
+ return None
118
+
119
+ p_arg_idx = 3 # inputs to Binomial are (rng, size, n, p)
120
+ binomial_rvs = rv_used_by (fgraph , beta_free_rv , Binomial , p_arg_idx )
121
+
122
+ if len (binomial_rvs ) != 1 :
123
+ # Question: Can we apply conjugacy when RV is used by more than one binomial?
124
+ return None
125
+
126
+ [binomial_rv ] = binomial_rvs
127
+
128
+ binomial_model_var = get_model_var_of_rv (fgraph , binomial_rv )
129
+ if binomial_model_var is None :
130
+ return None
131
+
132
+ # We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv)
133
+ a , b = get_dist_params (beta_rv )
134
+ n , _ = get_dist_params (binomial_rv )
135
+
136
+ # Use value of y in new graph to avoid circularity
137
+ y = binomial_model_var .owner .inputs [1 ]
138
+
139
+ conjugate_a = a + y
140
+ conjugate_b = b + (n - y )
141
+ extra_dims = range (binomial_rv .type .ndim - beta_rv .type .ndim )
142
+ if extra_dims :
143
+ conjugate_a = conjugate_a .sum (extra_dims )
144
+ conjugate_b = conjugate_b .sum (extra_dims )
145
+ conjugate_beta_rv = Beta .dist (conjugate_a , conjugate_b )
146
+
147
+ new_beta_rv = wrap_rv_and_conjugate_rv (fgraph , beta_rv , conjugate_beta_rv , [a , b , n , y ])
148
+ new_beta_free_rv = create_untransformed_free_rv (fgraph , new_beta_rv , beta_free_rv .name , beta_dims )
149
+ return [new_beta_free_rv ]
150
+
151
+
152
+ posterior_optimization_db .register (
153
+ beta_binomial_conjugacy .__name__ ,
154
+ beta_binomial_conjugacy ,
155
+ "conjugacy"
156
+ )
0 commit comments