Skip to content

Commit a9f582c

Browse files
committed
refactor: Use infer_static_shape from pytensor
1 parent 75c98ec commit a9f582c

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

pymc_experimental/model/transforms/autoreparam.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from pymc.pytensorf import toposort_replace
2323
from pytensor.graph.basic import Apply, Variable
24+
from pytensor.tensor.basic import infer_static_shape
2425
from pytensor.tensor.random.op import RandomVariable
2526

2627
_log = logging.getLogger("pmx")
@@ -176,12 +177,9 @@ def vip_reparam_node(
176177
) -> Tuple[ModelDeterministic, ModelNamed]:
177178
if not isinstance(node.op, RandomVariable | SymbolicRandomVariable):
178179
raise TypeError("Op should be RandomVariable type")
179-
_, size, *_ = node.inputs
180-
eval_size = size.eval(mode="FAST_COMPILE")
181-
if eval_size is not None:
182-
rv_shape = tuple(eval_size)
183-
else:
184-
rv_shape = ()
180+
rv = node.default_output()
181+
rv_shape_t, _ = infer_static_shape(rv.shape)
182+
rv_shape = pt.as_tensor(rv_shape_t).eval(mode="FAST_COMPILE")
185183
lam_name = f"{name}::lam_logit__"
186184
_log.debug(f"Creating {lam_name} with shape of {rv_shape}")
187185
logit_lam_ = pytensor.shared(

0 commit comments

Comments
 (0)