
Commit 04ec126

Add RandomVariable Op helpers to retrieve rng, size, and dist_params from a node, for readability
1 parent d5345d2 commit 04ec126
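
For orientation, here is a minimal usage sketch (not part of the diff below; it assumes only the standard `[rng, size, dtype, *dist_params]` input layout of `RandomVariable` nodes) contrasting the old positional indexing with the new helpers:

    # Sketch: accessing the inputs of a RandomVariable node before/after this commit
    import pytensor.tensor as pt

    x = pt.random.normal(0, 1, size=(2,))
    node, op = x.owner, x.owner.op

    # Before: positional indexing into node.inputs
    rng, size, dist_params = node.inputs[0], node.inputs[1], node.inputs[3:]

    # After: named helpers that encapsulate the input layout
    assert op.rng_param(node) is rng
    assert op.size_param(node) is size
    assert all(a is b for a, b in zip(op.dist_params(node), dist_params))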

File tree

5 files changed: +37 -21 lines changed


pytensor/link/jax/dispatch/random.py

Lines changed: 4 additions & 4 deletions
@@ -88,7 +88,7 @@ def jax_typify_Generator(rng, **kwargs):
 
 
 @jax_funcify.register(ptr.RandomVariable)
-def jax_funcify_RandomVariable(op, node, **kwargs):
+def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
     """JAX implementation of random variables."""
     rv = node.outputs[1]
     out_dtype = rv.type.dtype
@@ -101,7 +101,7 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
     if None in static_size:
         # Sometimes size can be constant folded during rewrites,
         # without the RandomVariable node being updated with new static types
-        size_param = node.inputs[1]
+        size_param = op.size_param(node)
         if isinstance(size_param, Constant):
             size_tuple = tuple(size_param.data)
             # PyTensor uses empty size to represent size = None
@@ -304,11 +304,11 @@ def sample_fn(rng, size, dtype, df, loc, scale):
 
 
 @jax_sample_fn.register(ptr.ChoiceWithoutReplacement)
-def jax_funcify_choice(op, node):
+def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
     """JAX implementation of `ChoiceRV`."""
 
     batch_ndim = op.batch_ndim(node)
-    a, *p, core_shape = node.inputs[3:]
+    a, *p, core_shape = op.dist_params(node)
     a_core_ndim, *p_core_ndim, _ = op.ndims_params
 
     if batch_ndim and a_core_ndim == 0:
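
An aside on the `Constant` branch above (a sketch assuming the standard PyTensor API, not code from this commit): when `size` is given as literals it enters the graph as a `Constant`, so a backend can recover a static size tuple from its `data`:

    from pytensor.graph.basic import Constant
    import pytensor.tensor as pt

    x = pt.random.normal(0, 1, size=(2, 3))
    size_param = x.owner.op.size_param(x.owner)
    assert isinstance(size_param, Constant)
    print(tuple(size_param.data))  # (2, 3)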

pytensor/link/numba/dispatch/random.py

Lines changed: 16 additions & 12 deletions
@@ -96,11 +96,14 @@ def make_numba_random_fn(node, np_random_func):
     The functions generated here add parameter broadcasting and the ``size``
     argument to the Numba-supported scalar ``np.random`` functions.
     """
-    if not isinstance(node.inputs[0].type, RandomStateType):
+    op: ptr.RandomVariable = node.op
+    rng_param = op.rng_param(node)
+    if not isinstance(rng_param.type, RandomStateType):
         raise TypeError("Numba does not support NumPy `Generator`s")
 
-    tuple_size = int(get_vector_length(node.inputs[1]))
-    size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])
+    tuple_size = int(get_vector_length(op.size_param(node)))
+    dist_params = op.dist_params(node)
+    size_dims = tuple_size - max(i.ndim for i in dist_params)
 
     # Make a broadcast-capable version of the Numba supported scalar sampling
     # function
@@ -126,7 +129,7 @@ def make_numba_random_fn(node, np_random_func):
     )
 
     bcast_fn_input_names = ", ".join(
-        [unique_names(i, force_unique=True) for i in node.inputs[3:]]
+        [unique_names(i, force_unique=True) for i in dist_params]
     )
     bcast_fn_global_env = {
         "np_random_func": np_random_func,
@@ -143,7 +146,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
     )
 
     random_fn_input_names = ", ".join(
-        ["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]]
+        ["rng", "size", "dtype"] + [unique_names(i) for i in dist_params]
     )
 
     # Now, create a Numba JITable function that implements the `size` parameter
@@ -244,7 +247,8 @@ def create_numba_random_fn(
         suffix_sep="_",
     )
 
-    np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]]
+    dist_params = op.dist_params(node)
+    np_names = [unique_names(i, force_unique=True) for i in dist_params]
     np_input_names = ", ".join(np_names)
     np_random_fn_src = f"""
 @numba_vectorize
@@ -300,9 +304,9 @@ def body_fn(a):
 
 
 @numba_funcify.register(ptr.CategoricalRV)
-def numba_funcify_CategoricalRV(op, node, **kwargs):
+def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
     out_dtype = node.outputs[1].type.numpy_dtype
-    size_len = int(get_vector_length(node.inputs[1]))
+    size_len = int(get_vector_length(op.size_param(node)))
     p_ndim = node.inputs[-1].ndim
 
     @numba_basic.numba_njit
@@ -331,9 +335,9 @@ def categorical_rv(rng, size, dtype, p):
 @numba_funcify.register(ptr.DirichletRV)
 def numba_funcify_DirichletRV(op, node, **kwargs):
     out_dtype = node.outputs[1].type.numpy_dtype
-    alphas_ndim = node.inputs[3].type.ndim
+    alphas_ndim = op.dist_params(node)[0].type.ndim
     neg_ind_shape_len = -alphas_ndim + 1
-    size_len = int(get_vector_length(node.inputs[1]))
+    size_len = int(get_vector_length(op.size_param(node)))
 
     if alphas_ndim > 1:
@@ -400,9 +404,9 @@ def choice_without_replacement_rv(rng, size, dtype, a, core_shape):
 
 
 @numba_funcify.register(ptr.PermutationRV)
-def numba_funcify_permutation(op, node, **kwargs):
+def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
     # PyTensor uses size=() to represent size=None
-    size_is_none = node.inputs[1].type.shape == (0,)
+    size_is_none = op.size_param(node).type.shape == (0,)
     batch_ndim = op.batch_ndim(node)
     x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
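
The "size=() represents size=None" convention referenced above can be observed directly; a small sketch (assuming the standard PyTensor random API):

    import pytensor.tensor as pt

    x = pt.random.normal(0, 1)                   # no size given, i.e. size=None
    size_param = x.owner.op.size_param(x.owner)
    print(size_param.type.shape)                 # (0,): an empty size vector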

pytensor/tensor/random/op.py

Lines changed: 12 additions & 0 deletions
@@ -372,6 +372,18 @@ def make_node(self, rng, size, dtype, *dist_params):
     def batch_ndim(self, node: Apply) -> int:
         return cast(int, node.default_output().type.ndim - self.ndim_supp)
 
+    def rng_param(self, node) -> Variable:
+        """Return the node input corresponding to the rng"""
+        return node.inputs[0]
+
+    def size_param(self, node) -> Variable:
+        """Return the node input corresponding to the size"""
+        return node.inputs[1]
+
+    def dist_params(self, node) -> Sequence[Variable]:
+        """Return the node inputs corresponding to dist params"""
+        return node.inputs[3:]
+
     def perform(self, node, inputs, outputs):
         rng_var_out, smpl_out = outputs
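
Note that `dist_params` slices from index 3, not 2: `node.inputs[2]` is the `dtype` input, per the `make_node(self, rng, size, dtype, *dist_params)` signature in the hunk header. An illustrative check (standard PyTensor API assumed):

    import pytensor.tensor as pt

    x = pt.random.normal(0.0, 1.0)
    node = x.owner
    print(node.inputs[2])             # the dtype input, skipped by dist_params
    print(node.op.dist_params(node))  # [loc, scale], i.e. node.inputs[3:]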

pytensor/tensor/random/rewriting/basic.py

Lines changed: 1 addition & 1 deletion
@@ -255,7 +255,7 @@ def is_nd_advanced_idx(idx, dtype):
         return False
 
     # Check that indexing does not act on support dims
-    batch_ndims = rv.ndim - rv_op.ndim_supp
+    batch_ndims = rv_op.batch_ndim(rv_node)
     # We decompose the boolean indexes, which makes it clear whether they act on support dims or not
     non_bool_indices = tuple(
         chain.from_iterable(
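
The two expressions are equivalent by definition, since `batch_ndim(node)` is `node.default_output().type.ndim - self.ndim_supp` (see op.py above). An illustrative sketch, assuming PyTensor's Dirichlet (whose `ndim_supp` is 1):

    import pytensor.tensor as pt

    d = pt.random.dirichlet(pt.ones((4, 3)))  # 4 batched Dirichlets of size 3
    node = d.owner
    print(d.type.ndim)               # 2: one batch dim + one support dim
    print(node.op.batch_ndim(node))  # 1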

tests/tensor/random/rewriting/test_basic.py

Lines changed: 4 additions & 4 deletions
@@ -111,9 +111,9 @@ def test_inplace_rewrites(rv_op):
     assert new_op._props_dict() == (op._props_dict() | {"inplace": True})
     assert all(
         np.array_equal(a.data, b.data)
-        for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
+        for a, b in zip(new_op.dist_params(new_node), op.dist_params(node))
     )
-    assert np.array_equal(new_out.owner.inputs[1].data, [])
+    assert np.array_equal(new_op.size_param(new_node).data, op.size_param(node).data)
 
 
 @config.change_flags(compute_test_value="raise")
@@ -400,7 +400,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
         assert new_out.owner.op == dist_op
         assert all(
             isinstance(i.owner.op, DimShuffle)
-            for i in new_out.owner.inputs[3:]
+            for i in new_out.owner.op.dist_params(new_out.owner)
             if i.owner
         )
     else:
@@ -793,7 +793,7 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
         assert isinstance(new_out.owner.op, RandomVariable)
         assert all(
             isinstance(i.owner.op, AdvancedSubtensor | AdvancedSubtensor1 | Subtensor)
-            for i in new_out.owner.inputs[3:]
+            for i in new_out.owner.op.dist_params(new_out.owner)
            if i.owner
         )
     else:
