@@ -67,17 +67,33 @@ def std(self):
     def __init_group__(self, group):
         super().__init_group__(group)
         if not self._check_user_params():
-            self.shared_params = self.create_shared_params(self._kwargs.get("start", None))
+            self.shared_params = self.create_shared_params(
+                self._kwargs.get("start", None), self._kwargs.get("start_sigma", None)
+            )
         self._finalize_init()
 
-    def create_shared_params(self, start=None):
+    def create_shared_params(self, start=None, start_sigma=None):
+        # NOTE: `Group._prepare_start` uses `self.model.free_RVs` to identify free variables and
+        # `DictToArrayBijection` to turn them into a flat array, while `Approximation.rslice` assumes that the free
+        # variables are given by `self.group` and that the mapping between original variables and flat array is given
+        # by `self.ordering`. In the cases I looked into these turn out to be the same, but there may be edge cases or
+        # future code changes that break this assumption.
         start = self._prepare_start(start)
-        rho = np.zeros((self.ddim,))
+        rho = self._prepare_start_sigma(start_sigma)
         return {
             "mu": aesara.shared(pm.floatX(start), "mu"),
             "rho": aesara.shared(pm.floatX(rho), "rho"),
         }
 
+    def _prepare_start_sigma(self, start_sigma):
+        rho = np.zeros((self.ddim,))
+        if start_sigma is not None:
+            for name, slice_, *_ in self.ordering.values():
+                sigma = start_sigma.get(name)
+                if sigma is not None:
+                    rho[slice_] = np.log(np.exp(np.abs(sigma)) - 1.0)
+        return rho
+
     @node_property
     def symbolic_random(self):
         initial = self.symbolic_initial
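
A minimal usage sketch, assuming the new `start_sigma` keyword is forwarded from
`pm.fit`/`ADVI` into the group's `self._kwargs` as the hunk above implies; the model and
the chosen value are illustrative only. The second half checks the inverse-softplus
identity that `_prepare_start_sigma` relies on: the mean-field approximation parametrizes
the standard deviation as sigma = softplus(rho), so setting rho = log(exp(|sigma|) - 1)
recovers the requested sigma.

import numpy as np
import pymc as pm

with pm.Model():
    x = pm.Normal("x", 0.0, 1.0)
    # Hypothetical call: assumes `start_sigma` reaches the mean-field group's kwargs.
    approx = pm.fit(n=1000, method="advi", start_sigma={"x": 2.0})

# Inverse-softplus check, mirroring the line added in `_prepare_start_sigma`.
sigma = 2.0
rho = np.log(np.exp(np.abs(sigma)) - 1.0)
assert np.isclose(np.log1p(np.exp(rho)), sigma)  # softplus(rho) == sigma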