Skip to content

Commit 9c313cb

Browse files
committed
Add informative error when attempting imputation on non-pure RandomVariables
1 parent ff7c2ff commit 9c313cb

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

pymc/model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from aesara.graph.fg import FunctionGraph
4545
from aesara.scalar import Cast
4646
from aesara.tensor.elemwise import Elemwise
47+
from aesara.tensor.random.op import RandomVariable
4748
from aesara.tensor.random.rewriting import local_subtensor_rv_lift
4849
from aesara.tensor.sharedvar import ScalarSharedVariable
4950
from aesara.tensor.var import TensorConstant, TensorVariable
@@ -1405,9 +1406,16 @@ def make_obs_var(
14051406
)
14061407
warnings.warn(impute_message, ImputationWarning)
14071408

1409+
if not isinstance(rv_var.owner.op, RandomVariable):
1410+
raise NotImplementedError(
1411+
"Automatic inputation is only supported for univariate RandomVariables."
1412+
f" {rv_var} of type {type(rv_var.owner.op)} is not supported."
1413+
)
1414+
14081415
if rv_var.owner.op.ndim_supp > 0:
14091416
raise NotImplementedError(
1410-
f"Automatic inputation is only supported for univariate RandomVariables, but {rv_var} is multivariate"
1417+
f"Automatic inputation is only supported for univariate "
1418+
f"RandomVariables, but {rv_var} is multivariate"
14111419
)
14121420

14131421
# We can get a random variable comprised of only the unobserved

pymc/tests/test_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,19 @@ def test_dims(self):
13771377
x = pm.Normal("x", observed=data, dims=("observed",))
13781378
assert model.RV_dims == {"x": ("observed",)}
13791379

1380+
def test_error_non_random_variable(self):
1381+
data = np.array([np.nan] * 3 + [0] * 7)
1382+
with pm.Model() as model:
1383+
msg = "x of type <class 'pymc.distributions.censored.CensoredRV'> is not supported"
1384+
with pytest.raises(NotImplementedError, match=msg):
1385+
x = pm.Censored(
1386+
"x",
1387+
pm.Normal.dist(),
1388+
lower=0,
1389+
upper=10,
1390+
observed=data,
1391+
)
1392+
13801393

13811394
class TestShared(SeededTest):
13821395
def test_deterministic(self):

0 commit comments

Comments
 (0)