Skip to content

Commit 413af04

Browse files
committed
Allow check_dist_not_registered to be called inside CustomDist
1 parent 4acd5a3 commit 413af04

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

pymc/util.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from pytensor.compile import SharedVariable
2828
from pytensor.graph.utils import ValidatingScratchpad
2929

30+
from pymc.exceptions import BlockModelAccessError
31+
3032

3133
class _UnsetType:
3234
"""Type for the `UNSET` object to make it look nice in `help(...)` outputs."""
@@ -367,7 +369,7 @@ def check_dist_not_registered(dist, model=None):
367369

368370
try:
369371
model = modelcontext(None)
370-
except TypeError:
372+
except (TypeError, BlockModelAccessError):
371373
pass
372374
else:
373375
if dist in model.basic_RVs:

tests/distributions/test_distribution.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,13 @@ def shifted_normal(mu, sigma, size):
590590
pm.logp(pm.NormalMixture.dist(w=w, mu=mus, sigma=sds), test_value).eval(),
591591
)
592592

593+
def test_symbolic_dist(self):
594+
# Test we can create a SymbolicDist inside a CustomDist
595+
def dist(size):
596+
return pm.Truncated.dist(pm.Beta.dist(1, 1, size=size), lower=0.1, upper=0.9)
597+
598+
assert pm.CustomDist.dist(dist=dist)
599+
593600

594601
class TestSymbolicRandomVariable:
595602
def test_inline(self):

0 commit comments

Comments
 (0)