Commit 48c8189

Remove RandomVariable dtype input
1 parent 04ec126 commit 48c8189

6 files changed, +118 -122 lines changed


pytensor/link/jax/dispatch/random.py

Lines changed: 2 additions & 2 deletions
@@ -114,15 +114,15 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
     if None in static_size:
         assert_size_argument_jax_compatible(node)
 
-        def sample_fn(rng, size, dtype, *parameters):
+        def sample_fn(rng, size, *parameters):
             # PyTensor uses empty size to represent size = None
             if jax.numpy.asarray(size).shape == (0,):
                 size = None
             return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)
 
     else:
 
-        def sample_fn(rng, size, dtype, *parameters):
+        def sample_fn(rng, size, *parameters):
             return jax_sample_fn(op, node=node)(
                 rng, static_size, out_dtype, *parameters
             )
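
In the JAX backend the jitted sampler no longer receives a runtime dtype argument; the output dtype (out_dtype) is captured when the sampler is built. A minimal, self-contained sketch of that closure pattern, where make_sampler and the normal draw are illustrative stand-ins rather than the dispatcher's actual code:

    import jax

    def make_sampler(out_dtype):
        # out_dtype is fixed at build time instead of being passed on every call
        def sample_fn(rng, size, *parameters):
            loc, scale = parameters
            return (jax.random.normal(rng, shape=size) * scale + loc).astype(out_dtype)
        return sample_fn

    sample_fn = make_sampler("float32")
    print(sample_fn(jax.random.PRNGKey(0), (3,), 0.0, 1.0).dtype)  # float32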

pytensor/link/numba/dispatch/random.py

Lines changed: 8 additions & 9 deletions
@@ -123,7 +123,6 @@ def make_numba_random_fn(node, np_random_func):
             "size_dims",
             "rng",
             "size",
-            "dtype",
         ],
         suffix_sep="_",
     )
@@ -146,7 +145,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
     )
 
     random_fn_input_names = ", ".join(
-        ["rng", "size", "dtype"] + [unique_names(i) for i in dist_params]
+        ["rng", "size"] + [unique_names(i) for i in dist_params]
     )
 
     # Now, create a Numba JITable function that implements the `size` parameter
@@ -243,7 +242,7 @@ def create_numba_random_fn(
     np_global_env["numba_vectorize"] = numba_basic.numba_vectorize
 
     unique_names = unique_name_generator(
-        [np_random_fn_name, *np_global_env.keys(), "rng", "size", "dtype"],
+        [np_random_fn_name, *np_global_env.keys(), "rng", "size"],
         suffix_sep="_",
     )
 
@@ -310,7 +309,7 @@ def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
     p_ndim = node.inputs[-1].ndim
 
     @numba_basic.numba_njit
-    def categorical_rv(rng, size, dtype, p):
+    def categorical_rv(rng, size, p):
         if not size_len:
             size_tpl = p.shape[:-1]
         else:
@@ -342,7 +341,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
     if alphas_ndim > 1:
 
         @numba_basic.numba_njit
-        def dirichlet_rv(rng, size, dtype, alphas):
+        def dirichlet_rv(rng, size, alphas):
             if size_len > 0:
                 size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
                 if (
@@ -365,7 +364,7 @@ def dirichlet_rv(rng, size, dtype, alphas):
     else:
 
         @numba_basic.numba_njit
-        def dirichlet_rv(rng, size, dtype, alphas):
+        def dirichlet_rv(rng, size, alphas):
             size = numba_ndarray.to_fixed_tuple(size, size_len)
             return (rng, np.random.dirichlet(alphas, size))
 
@@ -388,14 +387,14 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs):
     if op.has_p_param:
 
         @numba_basic.numba_njit
-        def choice_without_replacement_rv(rng, size, dtype, a, p, core_shape):
+        def choice_without_replacement_rv(rng, size, a, p, core_shape):
             core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
             samples = np.random.choice(a, size=core_shape, replace=False, p=p)
             return (rng, samples)
     else:
 
         @numba_basic.numba_njit
-        def choice_without_replacement_rv(rng, size, dtype, a, core_shape):
+        def choice_without_replacement_rv(rng, size, a, core_shape):
             core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
             samples = np.random.choice(a, size=core_shape, replace=False)
             return (rng, samples)
@@ -411,7 +410,7 @@ def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
     x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
 
     @numba_basic.numba_njit
-    def permutation_rv(rng, size, dtype, x):
+    def permutation_rv(rng, size, x):
         if batch_ndim:
             x_core_shape = x.shape[x_batch_ndim:]
             if size_is_none:

pytensor/tensor/random/op.py

Lines changed: 35 additions & 33 deletions
@@ -27,7 +27,7 @@
     normalize_size_param,
 )
 from pytensor.tensor.shape import shape_tuple
-from pytensor.tensor.type import TensorType, all_dtypes
+from pytensor.tensor.type import TensorType
 from pytensor.tensor.type_other import NoneConst
 from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
 from pytensor.tensor.variable import TensorVariable
@@ -65,7 +65,7 @@ def __init__(
         signature: str
             Numpy-like vectorized signature of the random variable.
         dtype: str (optional)
-            The dtype of the sampled output. If the value ``"floatX"`` is
+            The default dtype of the sampled output. If the value ``"floatX"`` is
             given, then ``dtype`` is set to ``pytensor.config.floatX``. If
             ``None`` (the default), the `dtype` keyword must be set when
             `RandomVariable.make_node` is called.
@@ -287,8 +287,8 @@ def extract_batch_shape(p, ps, n):
         return shape
 
     def infer_shape(self, fgraph, node, input_shapes):
-        _, size, _, *dist_params = node.inputs
-        _, size_shape, _, *param_shapes = input_shapes
+        _, size, *dist_params = node.inputs
+        _, size_shape, *param_shapes = input_shapes
 
         try:
             size_len = get_vector_length(size)
@@ -302,14 +302,34 @@ def infer_shape(self, fgraph, node, input_shapes):
         return [None, list(shape)]
 
     def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
-        res = super().__call__(rng, size, dtype, *args, **kwargs)
+        if dtype is None:
+            dtype = self.dtype
+        if dtype == "floatX":
+            dtype = config.floatX
+
+        # We need to recreate the Op with the right dtype
+        if dtype != self.dtype:
+            # Check we are not switching from float to int
+            if self.dtype is not None:
+                if dtype.startswith("float") != self.dtype.startswith("float"):
+                    raise ValueError(
+                        f"Cannot change the dtype of a {self.name} RV from {self.dtype} to {dtype}"
+                    )
+            props = self._props_dict()
+            props["dtype"] = dtype
+            new_op = type(self)(**props)
+            return new_op.__call__(
+                *args, size=size, name=name, rng=rng, dtype=dtype, **kwargs
+            )
+
+        res = super().__call__(rng, size, *args, **kwargs)
 
         if name is not None:
             res.name = name
 
         return res
 
-    def make_node(self, rng, size, dtype, *dist_params):
+    def make_node(self, rng, size, *dist_params):
         """Create a random variable node.
 
         Parameters
@@ -349,23 +369,10 @@ def make_node(self, rng, size, dtype, *dist_params):
 
         shape = self._infer_shape(size, dist_params)
         _, static_shape = infer_static_shape(shape)
-        dtype = self.dtype or dtype
 
-        if dtype == "floatX":
-            dtype = config.floatX
-        elif dtype is None or (isinstance(dtype, str) and dtype not in all_dtypes):
-            raise TypeError("dtype is unspecified")
-
-        if isinstance(dtype, str):
-            dtype_idx = constant(all_dtypes.index(dtype), dtype="int64")
-        else:
-            dtype_idx = constant(dtype, dtype="int64")
-
-        dtype = all_dtypes[dtype_idx.data]
-
-        inputs = (rng, size, dtype_idx, *dist_params)
-        out_var = TensorType(dtype=dtype, shape=static_shape)()
-        outputs = (rng.type(), out_var)
+        inputs = (rng, size, *dist_params)
+        out_type = TensorType(dtype=self.dtype, shape=static_shape)
+        outputs = (rng.type(), out_type())
 
         return Apply(self, inputs, outputs)
 
@@ -382,14 +389,12 @@ def size_param(self, node) -> Variable:
 
     def dist_params(self, node) -> Sequence[Variable]:
         """Return the node inpust corresponding to dist params"""
-        return node.inputs[3:]
+        return node.inputs[2:]
 
     def perform(self, node, inputs, outputs):
         rng_var_out, smpl_out = outputs
 
-        rng, size, dtype, *args = inputs
-
-        out_var = node.outputs[1]
+        rng, size, *args = inputs
 
         # If `size == []`, that means no size is enforced, and NumPy is trusted
         # to draw the appropriate number of samples, NumPy uses `size=None` to
@@ -408,11 +413,8 @@ def perform(self, node, inputs, outputs):
 
         smpl_val = self.rng_fn(rng, *([*args, size]))
 
-        if (
-            not isinstance(smpl_val, np.ndarray)
-            or str(smpl_val.dtype) != out_var.type.dtype
-        ):
-            smpl_val = _asarray(smpl_val, dtype=out_var.type.dtype)
+        if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype:
+            smpl_val = _asarray(smpl_val, dtype=self.dtype)
 
         smpl_out[0] = smpl_val
 
@@ -463,7 +465,7 @@ class DefaultGeneratorMakerOp(AbstractRNGConstructor):
 
 @_vectorize_node.register(RandomVariable)
 def vectorize_random_variable(
-    op: RandomVariable, node: Apply, rng, size, dtype, *dist_params
+    op: RandomVariable, node: Apply, rng, size, *dist_params
 ) -> Apply:
     # If size was provided originally and a new size hasn't been provided,
     # We extend it to accommodate the new input batch dimensions.
@@ -491,4 +493,4 @@ def vectorize_random_variable(
         new_size_dims = broadcasted_batch_shape[:new_ndim]
         size = concatenate([new_size_dims, size])
 
-    return op.make_node(rng, size, dtype, *dist_params)
+    return op.make_node(rng, size, *dist_params)
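
The core of the change is in RandomVariable itself: dtype is no longer a symbolic input of the Apply node, it lives only on the Op, and asking for a different dtype at call time recreates the Op with that property. A hedged sketch of how the new call path is expected to behave, where the 4-input count assumes the standard normal RV with loc/scale parameters and is not part of the diff:

    from pytensor.tensor.random.basic import normal

    x = normal(0, 1)
    assert len(x.owner.inputs) == 4            # rng, size, loc, scale -- no dtype input
    assert x.owner.op.dtype == x.type.dtype    # the output dtype comes from the Op

    y = normal(0, 1, dtype="float32")          # a different float dtype rebuilds the Op
    assert y.owner.op.dtype == "float32"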

pytensor/tensor/random/rewriting/basic.py

Lines changed: 6 additions & 6 deletions
@@ -81,7 +81,7 @@ def local_rv_size_lift(fgraph, node):
     if not isinstance(node.op, RandomVariable):
         return
 
-    rng, size, dtype, *dist_params = node.inputs
+    rng, size, *dist_params = node.inputs
 
     dist_params = broadcast_params(dist_params, node.op.ndims_params)
 
@@ -105,7 +105,7 @@ def local_rv_size_lift(fgraph, node):
     else:
         return
 
-    new_node = node.op.make_node(rng, None, dtype, *dist_params)
+    new_node = node.op.make_node(rng, None, *dist_params)
 
     if config.compute_test_value != "off":
         compute_test_value(new_node)
@@ -141,7 +141,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
         return False
 
     rv_op = rv_node.op
-    rng, size, dtype, *dist_params = rv_node.inputs
+    rng, size, *dist_params = rv_node.inputs
     rv = rv_node.default_output()
 
     # Check that Dimshuffle does not affect support dims
@@ -185,7 +185,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
         )
         new_dist_params.append(param.dimshuffle(param_new_order))
 
-    new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
+    new_node = rv_op.make_node(rng, new_size, *new_dist_params)
 
     if config.compute_test_value != "off":
         compute_test_value(new_node)
@@ -233,7 +233,7 @@ def is_nd_advanced_idx(idx, dtype):
         return None
 
     rv_op = rv_node.op
-    rng, size, dtype, *dist_params = rv_node.inputs
+    rng, size, *dist_params = rv_node.inputs
 
     # Parse indices
     idx_list = getattr(subtensor_op, "idx_list", None)
@@ -346,7 +346,7 @@ def is_nd_advanced_idx(idx, dtype):
         new_dist_params.append(batch_param[tuple(batch_indices)])
 
     # Create new RV
-    new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
+    new_node = rv_op.make_node(rng, new_size, *new_dist_params)
     new_rv = new_node.default_output()
 
     copy_stack_trace(rv, new_rv)
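
For the rewrites the change is purely positional: node inputs now unpack as (rng, size, *dist_params) and nodes are rebuilt through make_node with the same layout. A small sketch of that rebuild pattern, not taken from the diff:

    from pytensor.tensor.random.basic import normal

    x = normal(0, 1, size=(2, 3))
    rng, size, *dist_params = x.owner.inputs              # no dtype slot to skip anymore
    rebuilt = x.owner.op.make_node(rng, size, *dist_params)
    assert rebuilt.outputs[1].type.dtype == x.type.dtype  # dtype is taken from the Op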

tests/tensor/random/rewriting/test_basic.py

Lines changed: 6 additions & 5 deletions
@@ -12,6 +12,7 @@
 from pytensor.tensor import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.random.basic import (
+    NormalRV,
     categorical,
     dirichlet,
     multinomial,
@@ -397,7 +398,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
     )
 
     if lifted:
-        assert new_out.owner.op == dist_op
+        assert isinstance(new_out.owner.op, type(dist_op))
         assert all(
             isinstance(i.owner.op, DimShuffle)
             for i in new_out.owner.op.dist_params(new_out.owner)
@@ -832,7 +833,7 @@ def test_Subtensor_lift_restrictions():
     subtensor_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
     assert subtensor_node == y.owner
     assert isinstance(subtensor_node.op, Subtensor)
-    assert subtensor_node.inputs[0].owner.op == normal
+    assert isinstance(subtensor_node.inputs[0].owner.op, NormalRV)
 
     z = pt.ones(x.shape) - x[1]
 
@@ -850,7 +851,7 @@ def test_Subtensor_lift_restrictions():
     EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
 
     rv_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
-    assert rv_node.op == normal
+    assert isinstance(rv_node.op, NormalRV)
     assert isinstance(rv_node.inputs[-1].owner.op, Subtensor)
     assert isinstance(rv_node.inputs[-2].owner.op, Subtensor)
 
@@ -872,7 +873,7 @@ def test_Dimshuffle_lift_restrictions():
     dimshuffle_node = fg.outputs[0].owner.inputs[1].owner
     assert dimshuffle_node == y.owner
     assert isinstance(dimshuffle_node.op, DimShuffle)
-    assert dimshuffle_node.inputs[0].owner.op == normal
+    assert isinstance(dimshuffle_node.inputs[0].owner.op, NormalRV)
 
     z = pt.ones(x.shape) - y
 
@@ -890,7 +891,7 @@ def test_Dimshuffle_lift_restrictions():
     EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
 
     rv_node = fg.outputs[0].owner.inputs[1].owner
-    assert rv_node.op == normal
+    assert isinstance(rv_node.op, NormalRV)
     assert isinstance(rv_node.inputs[-1].owner.op, DimShuffle)
     assert isinstance(rv_node.inputs[-2].owner.op, DimShuffle)
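
The test updates follow from the Op being recreated whenever a concrete dtype is set: a node's op is no longer guaranteed to compare equal to the module-level `normal` instance, so equality checks give way to isinstance checks. A hedged illustration, not taken from the test file:

    from pytensor.tensor.random.basic import NormalRV, normal

    x = normal(0, 1)
    assert isinstance(x.owner.op, NormalRV)  # robust to the Op being recreated
    # `x.owner.op == normal` may fail: __call__ can return a node built by a new
    # NormalRV whose dtype prop differs from the default instance's.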
