Skip to content

Commit 7b39d7a

Browse files
committed
check for random ancestors as well
1 parent e3f66d9 commit 7b39d7a

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

pymc/data.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from pytensor.scalar import Cast
3434
from pytensor.tensor.elemwise import Elemwise
3535
from pytensor.tensor.random.basic import IntegersRV
36+
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
3637
from pytensor.tensor.type import TensorType
3738
from pytensor.tensor.variable import TensorConstant, TensorVariable
3839

@@ -148,6 +149,19 @@ def __str__(self):
148149
return "Minibatch"
149150

150151

152+
def first_inputs(r):
153+
if not r.owner:
154+
return
155+
156+
first_input = r.owner.inputs[0]
157+
yield first_input
158+
yield from first_inputs(first_input)
159+
160+
161+
def has_random_ancestor(r):
162+
return any(isinstance(i, RandomGeneratorSharedVariable) for i in first_inputs(r))
163+
164+
151165
def is_valid_observed(v) -> bool:
152166
if not isinstance(v, Variable):
153167
# Non-symbolic constant
@@ -165,6 +179,7 @@ def is_valid_observed(v) -> bool:
165179
and isinstance(v.owner.op.scalar_op, Cast)
166180
and is_valid_observed(v.owner.inputs[0])
167181
)
182+
or not has_random_ancestor(v)
168183
# Or Minibatch
169184
or (
170185
isinstance(v.owner.op, MinibatchOp)

0 commit comments

Comments
 (0)