File tree 1 file changed +5
-12
lines changed
1 file changed +5
-12
lines changed Original file line number Diff line number Diff line change @@ -387,24 +387,17 @@ def dist_params(self, node) -> Sequence[Variable]:
387
387
return node .inputs [2 :]
388
388
389
389
def perform (self , node , inputs , outputs ):
390
- rng_var_out , smpl_out = outputs
391
-
392
390
rng , size , * args = inputs
393
391
394
392
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
395
393
if not self .inplace :
396
394
rng = copy (rng )
397
395
398
- rng_var_out [0 ] = rng
399
-
400
- if size is not None :
401
- size = tuple (size )
402
- smpl_val = self .rng_fn (rng , * ([* args , size ]))
403
-
404
- if not isinstance (smpl_val , np .ndarray ) or str (smpl_val .dtype ) != self .dtype :
405
- smpl_val = np .asarray (smpl_val , dtype = self .dtype )
406
-
407
- smpl_out [0 ] = smpl_val
396
+ outputs [0 ][0 ] = rng
397
+ outputs [1 ][0 ] = np .asarray (
398
+ self .rng_fn (rng , * args , None if size is None else tuple (size )),
399
+ dtype = self .dtype ,
400
+ )
408
401
409
402
def grad (self , inputs , outputs ):
410
403
return [
You can’t perform that action at this time.
0 commit comments