File tree Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Original file line number Diff line number Diff line change 33
33
from pytensor .scalar import Cast
34
34
from pytensor .tensor .elemwise import Elemwise
35
35
from pytensor .tensor .random .basic import IntegersRV
36
+ from pytensor .tensor .random .var import RandomGeneratorSharedVariable
36
37
from pytensor .tensor .type import TensorType
37
38
from pytensor .tensor .variable import TensorConstant , TensorVariable
38
39
@@ -148,6 +149,19 @@ def __str__(self):
148
149
return "Minibatch"
149
150
150
151
152
+ def first_inputs (r ):
153
+ if not r .owner :
154
+ return
155
+
156
+ first_input = r .owner .inputs [0 ]
157
+ yield first_input
158
+ yield from first_inputs (first_input )
159
+
160
+
161
+ def has_random_ancestor (r ):
162
+ return any (isinstance (i , RandomGeneratorSharedVariable ) for i in first_inputs (r ))
163
+
164
+
151
165
def is_valid_observed (v ) -> bool :
152
166
if not isinstance (v , Variable ):
153
167
# Non-symbolic constant
@@ -165,6 +179,7 @@ def is_valid_observed(v) -> bool:
165
179
and isinstance (v .owner .op .scalar_op , Cast )
166
180
and is_valid_observed (v .owner .inputs [0 ])
167
181
)
182
+ or not has_random_ancestor (v )
168
183
# Or Minibatch
169
184
or (
170
185
isinstance (v .owner .op , MinibatchOp )
You can’t perform that action at this time.
0 commit comments