Skip to content

Commit 7bcb2bf

Browse files
committed
Ignore constant shape dependencies in MarginalModel
1 parent 644397b commit 7bcb2bf

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

pymc_experimental/marginal_model.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
__all__ = ["MarginalModel"]
2323

24+
from pytensor.tensor.shape import Shape
25+
2426

2527
class MarginalModel(Model):
2628
"""Subclass of PyMC Model that implements functionality for automatic
@@ -251,9 +253,24 @@ class FiniteDiscreteMarginalRV(MarginalRV):
251253
"""Base class for Finite Discrete Marginalized RVs"""
252254

253255

256+
def static_shape_ancestors(vars):
257+
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
258+
return [
259+
var
260+
for var in ancestors(vars)
261+
if (
262+
var.owner
263+
and isinstance(var.owner.op, Shape)
264+
# All static dims lengths of Shape input are known
265+
and None not in var.owner.inputs[0].type.shape
266+
)
267+
]
268+
269+
254270
def find_conditional_input_rvs(output_rvs, all_rvs):
255-
"""Find conditionally indepedent input RVs"""
271+
"""Find conditionally indepedent input RVs."""
256272
blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
273+
blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
257274
return [
258275
var
259276
for var in ancestors(output_rvs, blockers=blockers)

pymc_experimental/tests/test_marginal_model.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
from pymc.util import UNSET
1313
from scipy.special import logsumexp
1414

15-
from pymc_experimental.marginal_model import FiniteDiscreteMarginalRV, MarginalModel
15+
from pymc_experimental.marginal_model import (
16+
FiniteDiscreteMarginalRV,
17+
MarginalModel,
18+
is_conditional_dependent,
19+
)
1620

1721

1822
@pytest.fixture
@@ -411,3 +415,14 @@ def test_marginalized_transforms(transform, expected_warning):
411415
transform_name = transform.name
412416
assert f"sigma_{transform_name}__" in ip
413417
np.testing.assert_allclose(m.compile_logp()(ip), m_ref.compile_logp()(ip))
418+
419+
420+
def test_is_conditional_dependent_static_shape():
421+
"""Test that we don't consider dependencies through "constant" shape Ops"""
422+
x1 = pt.matrix("x1", shape=(None, 5))
423+
y1 = pt.random.normal(size=pt.shape(x1))
424+
assert is_conditional_dependent(y1, x1, [x1, y1])
425+
426+
x2 = pt.matrix("x2", shape=(9, 5))
427+
y2 = pt.random.normal(size=pt.shape(x2))
428+
assert not is_conditional_dependent(y2, x2, [x2, y2])

0 commit comments

Comments
 (0)