Skip to content

Commit 02d42d6

Browse files
committed
feature: Add Exponential distribution to autoreparameterization dispatch table
1 parent 87d4aea commit 02d42d6

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

pymc_experimental/model/transforms/autoreparam.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,44 @@ def _(
246246
return vip_rep
247247

248248

249+
@_vip_reparam_node.register
250+
def _(
251+
op: pm.Exponential,
252+
node: Apply,
253+
name: str,
254+
dims: List[Variable],
255+
transform: Optional[Transform],
256+
lam: pt.TensorVariable,
257+
) -> ModelDeterministic:
258+
rng, size, scale = node.inputs
259+
scale_centered = scale**lam
260+
scale_noncentered = scale ** (1 - lam)
261+
vip_rv_ = pm.Exponential.dist(
262+
scale=scale_centered,
263+
size=size,
264+
rng=rng,
265+
)
266+
vip_rv_value_ = vip_rv_.clone()
267+
vip_rv_.name = f"{name}::tau_"
268+
if transform is not None:
269+
vip_rv_value_.name = f"{vip_rv_.name}_{transform.name}__"
270+
else:
271+
vip_rv_value_.name = vip_rv_.name
272+
vip_rv = model_free_rv(
273+
vip_rv_,
274+
vip_rv_value_,
275+
transform,
276+
*dims,
277+
)
278+
279+
vip_rep_ = scale_noncentered * vip_rv
280+
281+
vip_rep_.name = name
282+
283+
vip_rep = model_deterministic(vip_rep_, *dims)
284+
return vip_rep
285+
286+
249287
def vip_reparametrize(
250288
model: pm.Model,
251289
var_names: Sequence[str],

0 commit comments

Comments
 (0)