Skip to content

Commit 3658687

Browse files
committed
Speedup random perform
1 parent 815b258 commit 3658687

File tree

1 file changed

+5
-12
lines changed
  • pytensor/tensor/random

1 file changed

+5
-12
lines changed

pytensor/tensor/random/op.py

+5-12
Original file line numberDiff line numberDiff line change
@@ -387,24 +387,17 @@ def dist_params(self, node) -> Sequence[Variable]:
387387
return node.inputs[2:]
388388

389389
def perform(self, node, inputs, outputs):
390-
rng_var_out, smpl_out = outputs
391-
392390
rng, size, *args = inputs
393391

394392
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
395393
if not self.inplace:
396394
rng = copy(rng)
397395

398-
rng_var_out[0] = rng
399-
400-
if size is not None:
401-
size = tuple(size)
402-
smpl_val = self.rng_fn(rng, *([*args, size]))
403-
404-
if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype:
405-
smpl_val = np.asarray(smpl_val, dtype=self.dtype)
406-
407-
smpl_out[0] = smpl_val
396+
outputs[0][0] = rng
397+
outputs[1][0] = np.asarray(
398+
self.rng_fn(rng, *args, None if size is None else tuple(size)),
399+
dtype=self.dtype,
400+
)
408401

409402
def grad(self, inputs, outputs):
410403
return [

0 commit comments

Comments
 (0)