@@ -3437,20 +3437,18 @@ class PermuteRowElements(Op):
3437
3437
permutation instead.
3438
3438
"""
3439
3439
3440
- __props__ = ()
3440
+ __props__ = ("inverse" ,)
3441
+
3442
+ def __init__ (self , inverse : bool ):
3443
+ super ().__init__ ()
3444
+ self .inverse = inverse
3441
3445
3442
- def make_node (self , x , y , inverse ):
3446
+ def make_node (self , x , y ):
3443
3447
x = as_tensor_variable (x )
3444
3448
y = as_tensor_variable (y )
3445
- if inverse : # as_tensor_variable does not accept booleans
3446
- inverse = as_tensor_variable (1 )
3447
- else :
3448
- inverse = as_tensor_variable (0 )
3449
3449
3450
3450
# y should contain integers
3451
3451
assert y .type .dtype in integer_dtypes
3452
- # Inverse should be an integer scalar
3453
- assert inverse .type .ndim == 0 and inverse .type .dtype in integer_dtypes
3454
3452
3455
3453
# Match shapes of x and y
3456
3454
x_dim = x .type .ndim
@@ -3467,7 +3465,7 @@ def make_node(self, x, y, inverse):
3467
3465
]
3468
3466
out_type = tensor (dtype = x .type .dtype , shape = out_shape )
3469
3467
3470
- inputlist = [x , y , inverse ]
3468
+ inputlist = [x , y ]
3471
3469
outputlist = [out_type ]
3472
3470
return Apply (self , inputlist , outputlist )
3473
3471
@@ -3520,7 +3518,7 @@ def _rec_perform(self, node, x, y, inverse, out, curdim):
3520
3518
raise ValueError (f"Dimension mismatch: { xs0 } , { ys0 } " )
3521
3519
3522
3520
def perform (self , node , inp , out ):
3523
- x , y , inverse = inp
3521
+ x , y = inp
3524
3522
(outs ,) = out
3525
3523
x_s = x .shape
3526
3524
y_s = y .shape
@@ -3543,7 +3541,7 @@ def perform(self, node, inp, out):
3543
3541
if outs [0 ] is None or outs [0 ].shape != out_s :
3544
3542
outs [0 ] = np .empty (out_s , dtype = x .dtype )
3545
3543
3546
- self ._rec_perform (node , x , y , inverse , outs [0 ], curdim = 0 )
3544
+ self ._rec_perform (node , x , y , self . inverse , outs [0 ], curdim = 0 )
3547
3545
3548
3546
def infer_shape (self , fgraph , node , in_shapes ):
3549
3547
from pytensor .tensor .math import maximum
@@ -3555,14 +3553,14 @@ def infer_shape(self, fgraph, node, in_shapes):
3555
3553
return [out_shape ]
3556
3554
3557
3555
def grad (self , inp , grads ):
3558
- from pytensor .tensor .math import Sum , eq
3556
+ from pytensor .tensor .math import Sum
3559
3557
3560
- x , y , inverse = inp
3558
+ x , y = inp
3561
3559
(gz ,) = grads
3562
3560
# First, compute the gradient wrt the broadcasted x.
3563
3561
# If 'inverse' is False (0), apply the inverse of y on gz.
3564
3562
# Else, apply y on gz.
3565
- gx = permute_row_elements (gz , y , eq ( inverse , 0 ) )
3563
+ gx = permute_row_elements (gz , y , not self . inverse )
3566
3564
3567
3565
# If x has been broadcasted along some axes, we need to sum
3568
3566
# the gradient over these axes, but keep the dimension (as
@@ -3599,20 +3597,17 @@ def grad(self, inp, grads):
3599
3597
if x .type .dtype in discrete_dtypes :
3600
3598
gx = x .zeros_like ()
3601
3599
3602
- # The elements of y and of inverse both affect the output,
3600
+ # The elements of y affect the output,
3603
3601
# so they are connected to the output,
3604
3602
# and the transformation isn't defined if their values
3605
3603
# are non-integer, so the gradient with respect to them is
3606
3604
# undefined
3607
3605
3608
- return [gx , grad_undefined (self , 1 , y ), grad_undefined (self , 1 , inverse )]
3609
-
3610
-
3611
- _permute_row_elements = PermuteRowElements ()
3606
+ return [gx , grad_undefined (self , 1 , y )]
3612
3607
3613
3608
3614
- def permute_row_elements (x , y , inverse = 0 ):
3615
- return _permute_row_elements ( x , y , inverse )
3609
+ def permute_row_elements (x , y , inverse = False ):
3610
+ return PermuteRowElements ( inverse = inverse )( x , y )
3616
3611
3617
3612
3618
3613
def inverse_permutation (perm ):
0 commit comments