Skip to content

Commit 9de291b

Browse files
committed
Revert "fix: use constant_fold instead of manual eval"
This reverts commit 557caec.
1 parent 557caec commit 9de291b

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

pymc_experimental/model/transforms/autoreparam.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
model_from_fgraph,
2020
model_named,
2121
)
22-
from pymc.pytensorf import constant_fold, toposort_replace
22+
from pymc.pytensorf import toposort_replace
2323
from pytensor.graph.basic import Apply, Variable
2424
from pytensor.tensor.random.op import RandomVariable
2525

@@ -176,8 +176,12 @@ def vip_reparam_node(
176176
) -> Tuple[ModelDeterministic, ModelNamed]:
177177
if not isinstance(node.op, RandomVariable | SymbolicRandomVariable):
178178
raise TypeError("Op should be RandomVariable type")
179-
rv = node.default_output()
180-
[rv_shape] = constant_fold([rv.shape], raise_if_not_constant=False)
179+
_, size, *_ = node.inputs
180+
eval_size = size.eval()
181+
if eval_size is not None:
182+
rv_shape = tuple(eval_size)
183+
else:
184+
rv_shape = ()
181185
lam_name = f"{name}::lam_logit__"
182186
_log.debug(f"Creating {lam_name} with shape of {rv_shape}")
183187
logit_lam_ = pytensor.shared(

0 commit comments

Comments
 (0)