Skip to content

Commit 5f248fc

Browse files
michaelosthegetwiecki
authored andcommitted
Implement method for fast evals of (un)transformed RV shapes
1 parent d8bb334 commit 5f248fc

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

pymc/model.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,14 @@
6262
from pymc.distributions.transforms import Transform
6363
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError
6464
from pymc.math import flatten_list
65-
from pymc.util import UNSET, WithMemoization, get_var_name, treedict, treelist
65+
from pymc.util import (
66+
UNSET,
67+
WithMemoization,
68+
get_transformed_name,
69+
get_var_name,
70+
treedict,
71+
treelist,
72+
)
6673
from pymc.vartypes import continuous_types, discrete_types, typefilter
6774

6875
__all__ = [
@@ -1603,6 +1610,33 @@ def update_start_vals(self, a: Dict[str, np.ndarray], b: Dict[str, np.ndarray]):
16031610

16041611
a.update({k: v for k, v in b.items() if k not in a})
16051612

1613+
def eval_rv_shapes(self) -> Dict[str, Tuple[int, ...]]:
1614+
"""Evaluates shapes of untransformed AND transformed free variables.
1615+
1616+
Returns
1617+
-------
1618+
shapes : dict
1619+
Maps untransformed and transformed variable names to shape tuples.
1620+
"""
1621+
names = []
1622+
outputs = []
1623+
for rv in self.free_RVs:
1624+
rv_var = self.rvs_to_values[rv]
1625+
transform = getattr(rv_var.tag, "transform", None)
1626+
if transform is not None:
1627+
names.append(get_transformed_name(rv.name, transform))
1628+
outputs.append(transform.forward(rv, rv).shape)
1629+
names.append(rv.name)
1630+
outputs.append(rv.shape)
1631+
f = aesara.function(
1632+
inputs=[],
1633+
outputs=outputs,
1634+
givens=[(obs, obs.tag.observations) for obs in self.observed_RVs],
1635+
mode=aesara.compile.mode.FAST_COMPILE,
1636+
on_unused_input="ignore",
1637+
)
1638+
return {name: tuple(shape) for name, shape in zip(names, f())}
1639+
16061640
def check_start_vals(self, start):
16071641
r"""Check that the starting values for MCMC do not cause the relevant log probability
16081642
to evaluate to something invalid (e.g. Inf or NaN)

pymc/tests/test_model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,28 @@ def test_soft_update_parent(self):
594594
assert_almost_equal(start["interv_interval__"], test_point["interv_interval__"])
595595

596596

597+
class TestShapeEvaluation:
598+
def test_eval_rv_shapes(self):
599+
with pm.Model(
600+
coords={
601+
"city": ["Sydney", "Las Vegas", "Düsseldorf"],
602+
}
603+
) as pmodel:
604+
pm.Data("budget", [1, 2, 3, 4], dims="year")
605+
pm.Normal("untransformed", size=(1, 2))
606+
pm.Uniform("transformed", size=(7,))
607+
obs = pm.Uniform("observed", size=(3,), observed=[0.1, 0.2, 0.3])
608+
pm.LogNormal("lognorm", mu=at.log(obs))
609+
pm.Normal("from_dims", dims=("city", "year"))
610+
shapes = pmodel.eval_rv_shapes()
611+
assert shapes["untransformed"] == (1, 2)
612+
assert shapes["transformed"] == (7,)
613+
assert shapes["transformed_interval__"] == (7,)
614+
assert shapes["lognorm"] == (3,)
615+
assert shapes["lognorm_log__"] == (3,)
616+
assert shapes["from_dims"] == (3, 4)
617+
618+
597619
class TestCheckStartVals(SeededTest):
598620
def setup_method(self):
599621
super().setup_method()

0 commit comments

Comments
 (0)