
Commit f81ffd8

Distinguish between size=None and size=() in RandomVariables
1 parent 48c8189 commit f81ffd8


10 files changed, +163 -152 lines
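For context, the distinction matters because NumPy treats the two cases differently: `size=None` lets the sampler infer the output shape by broadcasting the distribution parameters, while `size=()` explicitly requests a single scalar draw. Before this commit, PyTensor encoded `size=None` as an empty size vector, which conflated the two. A minimal NumPy-only illustration (not part of the diff):

```python
import numpy as np

rng = np.random.default_rng(0)

# size=None: the output shape is inferred from the broadcast parameters.
print(rng.normal(loc=[0.0, 1.0], scale=1.0, size=None).shape)  # (2,)

# size=(): a single scalar draw is requested, which conflicts with the
# batched `loc` parameter and raises a ValueError.
try:
    rng.normal(loc=[0.0, 1.0], scale=1.0, size=())
except ValueError as err:
    print("ValueError:", err)
```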

pytensor/link/jax/dispatch/random.py

Lines changed: 5 additions & 9 deletions
```diff
@@ -12,6 +12,7 @@
 from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
 from pytensor.link.jax.dispatch.shape import JAXShapeTuple
 from pytensor.tensor.shape import Shape, Shape_i
+from pytensor.tensor.type_other import NoneTypeT


 try:
@@ -93,7 +94,6 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
     rv = node.outputs[1]
     out_dtype = rv.type.dtype
     static_shape = rv.type.shape
-
     batch_ndim = op.batch_ndim(node)

     # Try to pass static size directly to JAX
@@ -102,11 +102,10 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
         # Sometimes size can be constant folded during rewrites,
         # without the RandomVariable node being updated with new static types
         size_param = op.size_param(node)
-        if isinstance(size_param, Constant):
-            size_tuple = tuple(size_param.data)
-            # PyTensor uses empty size to represent size = None
-            if len(size_tuple):
-                static_size = tuple(size_param.data)
+        if isinstance(size_param, Constant) and not isinstance(
+            size_param.type, NoneTypeT
+        ):
+            static_size = tuple(size_param.data)

     # If one dimension has unknown size, either the size is determined
     # by a `Shape` operator in which case JAX will compile, or it is
@@ -115,9 +114,6 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
         assert_size_argument_jax_compatible(node)

         def sample_fn(rng, size, *parameters):
-            # PyTensor uses empty size to represent size = None
-            if jax.numpy.asarray(size).shape == (0,):
-                size = None
             return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)

     else:
```
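With `size=None` now carried by a dedicated `NoneTypeT` input, the JAX backend no longer has to guess whether an empty constant vector means "no size". A minimal sketch of the new check, assuming PyTensor is installed (`static_size_of` is a hypothetical helper, not part of the diff):

```python
import pytensor.tensor as pt
from pytensor.graph.basic import Constant
from pytensor.tensor.type_other import NoneConst, NoneTypeT

def static_size_of(size_param):
    # A constant, non-None size can be read off the graph directly;
    # a NoneTypeT size means "let the backend infer the shape".
    if isinstance(size_param, Constant) and not isinstance(size_param.type, NoneTypeT):
        return tuple(size_param.data)
    return None

print(static_size_of(NoneConst))                           # None -> no static size
print(static_size_of(pt.constant([2, 3], dtype="int64")))  # (2, 3) as int64 scalars
print(static_size_of(pt.constant([], dtype="int64")))      # () -> a real empty size
```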

pytensor/link/numba/dispatch/random.py

Lines changed: 33 additions & 16 deletions
```diff
@@ -21,6 +21,7 @@
 )
 from pytensor.tensor.basic import get_vector_length
 from pytensor.tensor.random.type import RandomStateType
+from pytensor.tensor.type_other import NoneTypeT


 class RandomStateNumbaType(types.Type):
@@ -101,9 +102,13 @@ def make_numba_random_fn(node, np_random_func):
     if not isinstance(rng_param.type, RandomStateType):
         raise TypeError("Numba does not support NumPy `Generator`s")

-    tuple_size = int(get_vector_length(op.size_param(node)))
+    size_param = op.size_param(node)
+    size_len = (
+        None
+        if isinstance(size_param.type, NoneTypeT)
+        else int(get_vector_length(size_param))
+    )
     dist_params = op.dist_params(node)
-    size_dims = tuple_size - max(i.ndim for i in dist_params)

     # Make a broadcast-capable version of the Numba supported scalar sampling
     # function
@@ -119,7 +124,7 @@ def make_numba_random_fn(node, np_random_func):
         "np_random_func",
         "numba_vectorize",
         "to_fixed_tuple",
-        "tuple_size",
+        "size_len",
         "size_dims",
         "rng",
         "size",
@@ -155,10 +160,12 @@ def {bcast_fn_name}({bcast_fn_input_names}):
         "out_dtype": out_dtype,
     }

-    if tuple_size > 0:
+    if size_len is not None:
+        size_dims = size_len - max(i.ndim for i in dist_params)
+
         random_fn_body = dedent(
             f"""
-        size = to_fixed_tuple(size, tuple_size)
+        size = to_fixed_tuple(size, size_len)

        data = np.empty(size, dtype=out_dtype)
        for i in np.ndindex(size[:size_dims]):
@@ -170,7 +177,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
             {
                 "np": np,
                 "to_fixed_tuple": numba_ndarray.to_fixed_tuple,
-                "tuple_size": tuple_size,
+                "size_len": size_len,
                 "size_dims": size_dims,
             }
         )
@@ -305,19 +312,24 @@ def body_fn(a):
 @numba_funcify.register(ptr.CategoricalRV)
 def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
     out_dtype = node.outputs[1].type.numpy_dtype
-    size_len = int(get_vector_length(op.size_param(node)))
+    size_param = op.size_param(node)
+    size_len = (
+        None
+        if isinstance(size_param.type, NoneTypeT)
+        else int(get_vector_length(size_param))
+    )
     p_ndim = node.inputs[-1].ndim

     @numba_basic.numba_njit
     def categorical_rv(rng, size, p):
-        if not size_len:
+        if size_len is None:
             size_tpl = p.shape[:-1]
         else:
             size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
             p = np.broadcast_to(p, size_tpl + p.shape[-1:])

         # Workaround https://github.com/numba/numba/issues/8975
-        if not size_len and p_ndim == 1:
+        if size_len is None and p_ndim == 1:
             unif_samples = np.asarray(np.random.uniform(0, 1))
         else:
             unif_samples = np.random.uniform(0, 1, size_tpl)
@@ -336,22 +348,27 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
     out_dtype = node.outputs[1].type.numpy_dtype
     alphas_ndim = op.dist_params(node)[0].type.ndim
     neg_ind_shape_len = -alphas_ndim + 1
-    size_len = int(get_vector_length(op.size_param(node)))
+    size_param = op.size_param(node)
+    size_len = (
+        None
+        if isinstance(size_param.type, NoneTypeT)
+        else int(get_vector_length(size_param))
+    )

     if alphas_ndim > 1:

         @numba_basic.numba_njit
         def dirichlet_rv(rng, size, alphas):
-            if size_len > 0:
+            if size_len is None:
+                samples_shape = alphas.shape
+            else:
                 size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
                 if (
                     0 < alphas.ndim - 1 <= len(size_tpl)
                     and size_tpl[neg_ind_shape_len:] != alphas.shape[:-1]
                 ):
                     raise ValueError("Parameters shape and size do not match.")
                 samples_shape = size_tpl + alphas.shape[-1:]
-            else:
-                samples_shape = alphas.shape

             res = np.empty(samples_shape, dtype=out_dtype)
             alphas_bcast = np.broadcast_to(alphas, samples_shape)
@@ -365,7 +382,8 @@ def dirichlet_rv(rng, size, alphas):

         @numba_basic.numba_njit
         def dirichlet_rv(rng, size, alphas):
-            size = numba_ndarray.to_fixed_tuple(size, size_len)
+            if size_len is not None:
+                size = numba_ndarray.to_fixed_tuple(size, size_len)
             return (rng, np.random.dirichlet(alphas, size))

         return dirichlet_rv
@@ -404,8 +422,7 @@ def choice_without_replacement_rv(rng, size, a, core_shape):

 @numba_funcify.register(ptr.PermutationRV)
 def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
-    # PyTensor uses size=() to represent size=None
-    size_is_none = op.size_param(node).type.shape == (0,)
+    size_is_none = isinstance(op.size_param(node).type, NoneTypeT)
     batch_ndim = op.batch_ndim(node)
     x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
```
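The Numba dispatchers bake `size_len` into the jitted closure as a compile-time constant, so `size_len is None` is resolved when each specialization is compiled rather than at runtime. A simplified sketch of that pattern, assuming Numba is installed (`make_sampler` is illustrative, not the actual dispatcher):

```python
import numba
import numpy as np

def make_sampler(size_len):
    # `size_len` is frozen into the closure; Numba prunes the dead branch
    # when it compiles each specialization.
    @numba.njit
    def sampler(size, p):
        if size_len is None:
            size_tpl = p.shape[:-1]  # infer the batch shape from `p`
        else:
            size_tpl = size  # in the real code: to_fixed_tuple(size, size_len)
        return np.random.uniform(0.0, 1.0, size_tpl)

    return sampler

p = np.ones((4, 3))
print(make_sampler(size_len=None)(None, p).shape)  # (4,)
print(make_sampler(size_len=2)((2, 4), p).shape)   # (2, 4)
```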

pytensor/tensor/random/basic.py

Lines changed: 7 additions & 10 deletions
```diff
@@ -914,12 +914,11 @@ def rng_fn(cls, rng, mean, cov, size):
         # multivariate normals (or any other multivariate distributions),
         # so we need to implement that here

-        size = tuple(size or ())
-        if size:
+        if size is None:
+            mean, cov = broadcast_params([mean, cov], [1, 2])
+        else:
             mean = np.broadcast_to(mean, size + mean.shape[-1:])
             cov = np.broadcast_to(cov, size + cov.shape[-2:])
-        else:
-            mean, cov = broadcast_params([mean, cov], [1, 2])

         res = np.empty(mean.shape)
         for idx in np.ndindex(mean.shape[:-1]):
@@ -1800,13 +1799,11 @@ def __call__(self, n, p, size=None, **kwargs):
     @classmethod
     def rng_fn(cls, rng, n, p, size):
         if n.ndim > 0 or p.ndim > 1:
-            size = tuple(size or ())
-
-            if size:
+            if size is None:
+                n, p = broadcast_params([n, p], [0, 1])
+            else:
                 n = np.broadcast_to(n, size)
                 p = np.broadcast_to(p, size + p.shape[-1:])
-            else:
-                n, p = broadcast_params([n, p], [0, 1])

             res = np.empty(p.shape, dtype=cls.dtype)
             for idx in np.ndindex(p.shape[:-1]):
@@ -2155,7 +2152,7 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
     def rng_fn(self, rng, x, size):
         # We don't have access to the node in rng_fn :(
         x_batch_ndim = x.ndim - self.ndims_params[0]
-        batch_ndim = max(x_batch_ndim, len(size or ()))
+        batch_ndim = max(x_batch_ndim, 0 if size is None else len(size))

         if batch_ndim:
             # rng.permutation has no concept of batch dims
```
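These `rng_fn` implementations now branch on `size is None` instead of on the truthiness of a coerced tuple, so an explicit `size=()` is no longer silently treated as "no size". A NumPy-only sketch of the multivariate-normal branch, using `np.broadcast_shapes` as a stand-in for PyTensor's `broadcast_params` helper:

```python
import numpy as np

def broadcast_mean_cov(mean, cov, size):
    if size is None:
        # Broadcast the batch dims of mean (core ndim 1) and cov (core ndim 2)
        # against each other, mirroring broadcast_params([mean, cov], [1, 2]).
        batch = np.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
        return (
            np.broadcast_to(mean, batch + mean.shape[-1:]),
            np.broadcast_to(cov, batch + cov.shape[-2:]),
        )
    size = tuple(size)
    return (
        np.broadcast_to(mean, size + mean.shape[-1:]),
        np.broadcast_to(cov, size + cov.shape[-2:]),
    )

mean, cov = np.zeros((5, 2)), np.eye(2)
m, c = broadcast_mean_cov(mean, cov, size=None)
print(m.shape, c.shape)  # (5, 2) (5, 2, 2)
m, c = broadcast_mean_cov(mean, cov, size=(3, 5))
print(m.shape, c.shape)  # (3, 5, 2) (3, 5, 2, 2)
```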

pytensor/tensor/random/op.py

Lines changed: 14 additions & 30 deletions
```diff
@@ -16,7 +16,6 @@
     as_tensor_variable,
     concatenate,
     constant,
-    get_underlying_scalar_constant_value,
     get_vector_length,
     infer_static_shape,
 )
@@ -28,7 +27,7 @@
 )
 from pytensor.tensor.shape import shape_tuple
 from pytensor.tensor.type import TensorType
-from pytensor.tensor.type_other import NoneConst
+from pytensor.tensor.type_other import NoneConst, NoneTypeT
 from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
 from pytensor.tensor.variable import TensorVariable

@@ -196,10 +195,10 @@ def __str__(self):

     def _infer_shape(
         self,
-        size: TensorVariable,
+        size: TensorVariable | Variable,
         dist_params: Sequence[TensorVariable],
         param_shapes: Sequence[tuple[Variable, ...]] | None = None,
-    ) -> TensorVariable | tuple[ScalarVariable, ...]:
+    ) -> tuple[ScalarVariable | TensorVariable, ...]:
         """Compute the output shape given the size and distribution parameters.

         Parameters
@@ -225,9 +224,9 @@ def _infer_shape(
             self._supp_shape_from_params(dist_params, param_shapes=param_shapes)
         )

-        size_len = get_vector_length(size)
+        if not isinstance(size.type, NoneTypeT):
+            size_len = get_vector_length(size)

-        if size_len > 0:
             # Fail early when size is incompatible with parameters
             for i, (param, param_ndim_supp) in enumerate(
                 zip(dist_params, self.ndims_params)
@@ -281,21 +280,11 @@ def extract_batch_shape(p, ps, n):

         shape = batch_shape + supp_shape

-        if not shape:
-            shape = constant([], dtype="int64")
-
         return shape

     def infer_shape(self, fgraph, node, input_shapes):
         _, size, *dist_params = node.inputs
-        _, size_shape, *param_shapes = input_shapes
-
-        try:
-            size_len = get_vector_length(size)
-        except ValueError:
-            size_len = get_underlying_scalar_constant_value(size_shape[0])
-
-        size = tuple(size[n] for n in range(size_len))
+        _, _, *param_shapes = input_shapes

         shape = self._infer_shape(size, dist_params, param_shapes=param_shapes)

@@ -367,8 +356,8 @@ def make_node(self, rng, size, *dist_params):
                 "The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
             )

-        shape = self._infer_shape(size, dist_params)
-        _, static_shape = infer_static_shape(shape)
+        inferred_shape = self._infer_shape(size, dist_params)
+        _, static_shape = infer_static_shape(inferred_shape)

         inputs = (rng, size, *dist_params)
         out_type = TensorType(dtype=self.dtype, shape=static_shape)
@@ -396,21 +385,14 @@ def perform(self, node, inputs, outputs):

         rng, size, *args = inputs

-        # If `size == []`, that means no size is enforced, and NumPy is trusted
-        # to draw the appropriate number of samples, NumPy uses `size=None` to
-        # represent that. Otherwise, NumPy expects a tuple.
-        if np.size(size) == 0:
-            size = None
-        else:
-            size = tuple(size)
-
-        # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng`
-        # otherwise.
+        # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
         if not self.inplace:
             rng = copy(rng)

         rng_var_out[0] = rng

+        if size is not None:
+            size = tuple(size)
         smpl_val = self.rng_fn(rng, *([*args, size]))

         if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype:
@@ -473,7 +455,9 @@ def vectorize_random_variable(

     original_dist_params = op.dist_params(node)
     old_size = op.size_param(node)
-    len_old_size = get_vector_length(old_size)
+    len_old_size = (
+        None if isinstance(old_size.type, NoneTypeT) else get_vector_length(old_size)
+    )

     original_expanded_dist_params = explicit_expand_dims(
         original_dist_params, op.ndims_params, len_old_size
```
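At `perform` time the op can now hand `size` to NumPy essentially unchanged: `None` passes through, and anything else is converted to the tuple NumPy samplers expect. A minimal sketch of this calling convention (the `draw` helper is illustrative, not part of the diff):

```python
import numpy as np

def draw(rng, size, loc):
    # Mirrors the new perform(): only materialize a tuple for a real size.
    if size is not None:
        size = tuple(size)
    return rng.normal(loc=loc, size=size)

rng = np.random.default_rng(42)
print(draw(rng, None, loc=np.zeros(3)).shape)              # (3,): inferred from loc
print(draw(rng, np.array([2, 3]), loc=np.zeros(3)).shape)  # (2, 3)
```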
