File tree Expand file tree Collapse file tree 1 file changed +4
-6
lines changed
pymc_experimental/model/transforms Expand file tree Collapse file tree 1 file changed +4
-6
lines changed Original file line number Diff line number Diff line change 21
21
)
22
22
from pymc .pytensorf import toposort_replace
23
23
from pytensor .graph .basic import Apply , Variable
24
+ from pytensor .tensor .basic import infer_static_shape
24
25
from pytensor .tensor .random .op import RandomVariable
25
26
26
27
_log = logging .getLogger ("pmx" )
@@ -176,12 +177,9 @@ def vip_reparam_node(
176
177
) -> Tuple [ModelDeterministic , ModelNamed ]:
177
178
if not isinstance (node .op , RandomVariable | SymbolicRandomVariable ):
178
179
raise TypeError ("Op should be RandomVariable type" )
179
- _ , size , * _ = node .inputs
180
- eval_size = size .eval (mode = "FAST_COMPILE" )
181
- if eval_size is not None :
182
- rv_shape = tuple (eval_size )
183
- else :
184
- rv_shape = ()
180
+ rv = node .default_output ()
181
+ rv_shape_t , _ = infer_static_shape (rv .shape )
182
+ rv_shape = pt .as_tensor (rv_shape_t ).eval (mode = "FAST_COMPILE" )
185
183
lam_name = f"{ name } ::lam_logit__"
186
184
_log .debug (f"Creating { lam_name } with shape of { rv_shape } " )
187
185
logit_lam_ = pytensor .shared (
You can’t perform that action at this time.
0 commit comments