@@ -147,32 +147,45 @@ class BaseTrace(IBaseTrace):
147
147
use different test point that might be with changed variables shapes
148
148
"""
149
149
150
- def __init__ (self , name , model = None , vars = None , test_point = None ):
151
- self .name = name
152
-
150
+ def __init__ (
151
+ self ,
152
+ name = None ,
153
+ model = None ,
154
+ vars = None ,
155
+ test_point = None ,
156
+ * ,
157
+ fn = None ,
158
+ var_shapes = None ,
159
+ var_dtypes = None ,
160
+ ):
153
161
model = modelcontext (model )
154
- self . model = model
162
+
155
163
if vars is None :
156
164
vars = model .unobserved_value_vars
157
165
158
166
unnamed_vars = {var for var in vars if var .name is None }
159
167
if unnamed_vars :
160
168
raise Exception (f"Can't trace unnamed variables: { unnamed_vars } " )
161
- self . vars = vars
162
- self . varnames = [ var . name for var in vars ]
163
- self . fn = model .compile_fn (vars , inputs = model .value_vars , on_unused_input = "ignore" )
169
+
170
+ if fn is None :
171
+ fn = model .compile_fn (vars , inputs = model .value_vars , on_unused_input = "ignore" )
164
172
165
173
# Get variable shapes. Most backends will need this
166
174
# information.
167
- if test_point is None :
168
- test_point = model .initial_point ()
169
- else :
170
- test_point_ = model .initial_point ().copy ()
171
- test_point_ .update (test_point )
172
- test_point = test_point_
173
- var_values = list (zip (self .varnames , self .fn (test_point )))
174
- self .var_shapes = {var : value .shape for var , value in var_values }
175
- self .var_dtypes = {var : value .dtype for var , value in var_values }
175
+ if var_shapes is None or var_dtypes is None :
176
+ if test_point is None :
177
+ test_point = model .initial_point ()
178
+ var_values = tuple (zip (vars , fn (** test_point )))
179
+ var_shapes = {var .name : value .shape for var , value in var_values }
180
+ var_dtypes = {var .name : value .dtype for var , value in var_values }
181
+
182
+ self .name = name
183
+ self .model = model
184
+ self .fn = fn
185
+ self .vars = vars
186
+ self .varnames = [var .name for var in vars ]
187
+ self .var_shapes = var_shapes
188
+ self .var_dtypes = var_dtypes
176
189
self .chain = None
177
190
self ._is_base_setup = False
178
191
self .sampler_vars = None
0 commit comments