File tree Expand file tree Collapse file tree 2 files changed +22
-1
lines changed Expand file tree Collapse file tree 2 files changed +22
-1
lines changed Original file line number Diff line number Diff line change 44
44
from aesara .graph .fg import FunctionGraph
45
45
from aesara .scalar import Cast
46
46
from aesara .tensor .elemwise import Elemwise
47
+ from aesara .tensor .random .op import RandomVariable
47
48
from aesara .tensor .random .rewriting import local_subtensor_rv_lift
48
49
from aesara .tensor .sharedvar import ScalarSharedVariable
49
50
from aesara .tensor .var import TensorConstant , TensorVariable
@@ -1405,9 +1406,16 @@ def make_obs_var(
1405
1406
)
1406
1407
warnings .warn (impute_message , ImputationWarning )
1407
1408
1409
+ if not isinstance (rv_var .owner .op , RandomVariable ):
1410
+ raise NotImplementedError (
1411
+ "Automatic inputation is only supported for univariate RandomVariables."
1412
+ f" { rv_var } of type { type (rv_var .owner .op )} is not supported."
1413
+ )
1414
+
1408
1415
if rv_var .owner .op .ndim_supp > 0 :
1409
1416
raise NotImplementedError (
1410
- f"Automatic inputation is only supported for univariate RandomVariables, but { rv_var } is multivariate"
1417
+ f"Automatic inputation is only supported for univariate "
1418
+ f"RandomVariables, but { rv_var } is multivariate"
1411
1419
)
1412
1420
1413
1421
# We can get a random variable comprised of only the unobserved
Original file line number Diff line number Diff line change @@ -1377,6 +1377,19 @@ def test_dims(self):
1377
1377
x = pm .Normal ("x" , observed = data , dims = ("observed" ,))
1378
1378
assert model .RV_dims == {"x" : ("observed" ,)}
1379
1379
1380
+ def test_error_non_random_variable (self ):
1381
+ data = np .array ([np .nan ] * 3 + [0 ] * 7 )
1382
+ with pm .Model () as model :
1383
+ msg = "x of type <class 'pymc.distributions.censored.CensoredRV'> is not supported"
1384
+ with pytest .raises (NotImplementedError , match = msg ):
1385
+ x = pm .Censored (
1386
+ "x" ,
1387
+ pm .Normal .dist (),
1388
+ lower = 0 ,
1389
+ upper = 10 ,
1390
+ observed = data ,
1391
+ )
1392
+
1380
1393
1381
1394
class TestShared (SeededTest ):
1382
1395
def test_deterministic (self ):
You can’t perform that action at this time.
0 commit comments