Skip to content

Commit 9aa89ee

Browse files
committed
Fail early in RandomVariable.make_node when size is incompatible with parameters dimensionality
1 parent df2f8a5 commit 9aa89ee

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

pytensor/tensor/random/op.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,23 @@ def make_node(self, rng, size, dtype, *dist_params):
319319
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
320320
)
321321

322+
# Fail early when size is incompatible with parameters
323+
size_len = get_vector_length(size)
324+
if size_len:
325+
for i, (param, param_ndim_supp) in enumerate(
326+
zip(dist_params, self.ndims_params)
327+
):
328+
param_batched_dims = param.ndim - param_ndim_supp
329+
if param_batched_dims > size_len:
330+
raise ValueError(
331+
f"Size length is incompatible with batched dimensions of parameter {i} {param}:\n"
332+
f"len(size) = {size_len}, len(batched dims {param}) = {param_batched_dims}. "
333+
f"Size length must be 0 or >= {param_batched_dims}"
334+
)
335+
322336
shape = self._infer_shape(size, dist_params)
323337
_, static_shape = infer_static_shape(shape)
338+
324339
dtype = self.dtype or dtype
325340

326341
if dtype == "floatX":

tests/tensor/random/test_op.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,17 @@ def test_random_maker_ops_no_seed():
217217
z = function(inputs=[], outputs=[default_rng()])()
218218
aes_res = z[0]
219219
assert isinstance(aes_res, np.random.Generator)
220+
221+
222+
def test_RandomVariable_incompatible_size():
223+
rv_op = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
224+
with pytest.raises(
225+
ValueError, match="Size length is incompatible with batched dimensions"
226+
):
227+
rv_op(np.zeros((1, 3)), 1, size=(3,))
228+
229+
rv_op = RandomVariable("dirichlet", 0, [1], config.floatX, inplace=True)
230+
with pytest.raises(
231+
ValueError, match="Size length is incompatible with batched dimensions"
232+
):
233+
rv_op(np.zeros((2, 4, 3)), 1, size=(4,))

0 commit comments

Comments
 (0)