|
62 | 62 | from pymc.distributions.transforms import Transform
|
63 | 63 | from pymc.exceptions import ImputationWarning, SamplingError, ShapeError
|
64 | 64 | from pymc.math import flatten_list
|
65 |
| -from pymc.util import UNSET, WithMemoization, get_var_name, treedict, treelist |
| 65 | +from pymc.util import ( |
| 66 | + UNSET, |
| 67 | + WithMemoization, |
| 68 | + get_transformed_name, |
| 69 | + get_var_name, |
| 70 | + treedict, |
| 71 | + treelist, |
| 72 | +) |
66 | 73 | from pymc.vartypes import continuous_types, discrete_types, typefilter
|
67 | 74 |
|
68 | 75 | __all__ = [
|
@@ -1603,6 +1610,33 @@ def update_start_vals(self, a: Dict[str, np.ndarray], b: Dict[str, np.ndarray]):
|
1603 | 1610 |
|
1604 | 1611 | a.update({k: v for k, v in b.items() if k not in a})
|
1605 | 1612 |
|
| 1613 | + def eval_rv_shapes(self) -> Dict[str, Tuple[int, ...]]: |
| 1614 | + """Evaluates shapes of untransformed AND transformed free variables. |
| 1615 | +
|
| 1616 | + Returns |
| 1617 | + ------- |
| 1618 | + shapes : dict |
| 1619 | + Maps untransformed and transformed variable names to shape tuples. |
| 1620 | + """ |
| 1621 | + names = [] |
| 1622 | + outputs = [] |
| 1623 | + for rv in self.free_RVs: |
| 1624 | + rv_var = self.rvs_to_values[rv] |
| 1625 | + transform = getattr(rv_var.tag, "transform", None) |
| 1626 | + if transform is not None: |
| 1627 | + names.append(get_transformed_name(rv.name, transform)) |
| 1628 | + outputs.append(transform.forward(rv, rv).shape) |
| 1629 | + names.append(rv.name) |
| 1630 | + outputs.append(rv.shape) |
| 1631 | + f = aesara.function( |
| 1632 | + inputs=[], |
| 1633 | + outputs=outputs, |
| 1634 | + givens=[(obs, obs.tag.observations) for obs in self.observed_RVs], |
| 1635 | + mode=aesara.compile.mode.FAST_COMPILE, |
| 1636 | + on_unused_input="ignore", |
| 1637 | + ) |
| 1638 | + return {name: tuple(shape) for name, shape in zip(names, f())} |
| 1639 | + |
1606 | 1640 | def check_start_vals(self, start):
|
1607 | 1641 | r"""Check that the starting values for MCMC do not cause the relevant log probability
|
1608 | 1642 | to evaluate to something invalid (e.g. Inf or NaN)
|
|
0 commit comments