Skip to content

Commit 438418e

Browse files
committed
Inverse need not be a symbolic input in PermuteRowElements
1 parent c22e79e commit 438418e

File tree

3 files changed

+22
-28
lines changed

3 files changed

+22
-28
lines changed

pytensor/tensor/basic.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3437,20 +3437,18 @@ class PermuteRowElements(Op):
34373437
permutation instead.
34383438
"""
34393439

3440-
__props__ = ()
3440+
__props__ = ("inverse",)
3441+
3442+
def __init__(self, inverse: bool):
3443+
super().__init__()
3444+
self.inverse = inverse
34413445

3442-
def make_node(self, x, y, inverse):
3446+
def make_node(self, x, y):
34433447
x = as_tensor_variable(x)
34443448
y = as_tensor_variable(y)
3445-
if inverse: # as_tensor_variable does not accept booleans
3446-
inverse = as_tensor_variable(1)
3447-
else:
3448-
inverse = as_tensor_variable(0)
34493449

34503450
# y should contain integers
34513451
assert y.type.dtype in integer_dtypes
3452-
# Inverse should be an integer scalar
3453-
assert inverse.type.ndim == 0 and inverse.type.dtype in integer_dtypes
34543452

34553453
# Match shapes of x and y
34563454
x_dim = x.type.ndim
@@ -3467,7 +3465,7 @@ def make_node(self, x, y, inverse):
34673465
]
34683466
out_type = tensor(dtype=x.type.dtype, shape=out_shape)
34693467

3470-
inputlist = [x, y, inverse]
3468+
inputlist = [x, y]
34713469
outputlist = [out_type]
34723470
return Apply(self, inputlist, outputlist)
34733471

@@ -3520,7 +3518,7 @@ def _rec_perform(self, node, x, y, inverse, out, curdim):
35203518
raise ValueError(f"Dimension mismatch: {xs0}, {ys0}")
35213519

35223520
def perform(self, node, inp, out):
3523-
x, y, inverse = inp
3521+
x, y = inp
35243522
(outs,) = out
35253523
x_s = x.shape
35263524
y_s = y.shape
@@ -3543,7 +3541,7 @@ def perform(self, node, inp, out):
35433541
if outs[0] is None or outs[0].shape != out_s:
35443542
outs[0] = np.empty(out_s, dtype=x.dtype)
35453543

3546-
self._rec_perform(node, x, y, inverse, outs[0], curdim=0)
3544+
self._rec_perform(node, x, y, self.inverse, outs[0], curdim=0)
35473545

35483546
def infer_shape(self, fgraph, node, in_shapes):
35493547
from pytensor.tensor.math import maximum
@@ -3555,14 +3553,14 @@ def infer_shape(self, fgraph, node, in_shapes):
35553553
return [out_shape]
35563554

35573555
def grad(self, inp, grads):
3558-
from pytensor.tensor.math import Sum, eq
3556+
from pytensor.tensor.math import Sum
35593557

3560-
x, y, inverse = inp
3558+
x, y = inp
35613559
(gz,) = grads
35623560
# First, compute the gradient wrt the broadcasted x.
35633561
# If 'inverse' is False (0), apply the inverse of y on gz.
35643562
# Else, apply y on gz.
3565-
gx = permute_row_elements(gz, y, eq(inverse, 0))
3563+
gx = permute_row_elements(gz, y, not self.inverse)
35663564

35673565
# If x has been broadcasted along some axes, we need to sum
35683566
# the gradient over these axes, but keep the dimension (as
@@ -3599,20 +3597,17 @@ def grad(self, inp, grads):
35993597
if x.type.dtype in discrete_dtypes:
36003598
gx = x.zeros_like()
36013599

3602-
# The elements of y and of inverse both affect the output,
3600+
# The elements of y affect the output,
36033601
# so they are connected to the output,
36043602
# and the transformation isn't defined if their values
36053603
# are non-integer, so the gradient with respect to them is
36063604
# undefined
36073605

3608-
return [gx, grad_undefined(self, 1, y), grad_undefined(self, 1, inverse)]
3609-
3610-
3611-
_permute_row_elements = PermuteRowElements()
3606+
return [gx, grad_undefined(self, 1, y)]
36123607

36133608

3614-
def permute_row_elements(x, y, inverse=0):
3615-
return _permute_row_elements(x, y, inverse)
3609+
def permute_row_elements(x, y, inverse=False):
3610+
return PermuteRowElements(inverse=inverse)(x, y)
36163611

36173612

36183613
def inverse_permutation(perm):

pytensor/tensor/rewriting/subtensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1147,7 +1147,7 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
11471147
val = switch(le(len2, 0), len1 + 1, val)
11481148
val = switch(ge(sl2, len2), len1 + 1, val)
11491149
val = switch(lt(sl2, 0), -len1 - 1, val)
1150-
if sl1.step:
1150+
if sl1.step is not None:
11511151
val = switch(eq(sl1.step, 0), len1 + 1, val)
11521152
return val
11531153
else:

tests/tensor/test_basic.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4047,21 +4047,20 @@ def test_PermuteRowElements(self):
40474047
advec = dvector()
40484048
aivec = ivector()
40494049

4050-
abool = True
40514050
rng = np.random.default_rng(utt.fetch_seed())
40524051
advec_val = random(5)
40534052
aivec_val = rng.permutation(5).astype("int32")
40544053
self._compile_and_check(
40554054
[advec, aivec],
4056-
[PermuteRowElements()(advec, aivec, abool)],
4055+
[PermuteRowElements(inverse=True)(advec, aivec)],
40574056
[advec_val, aivec_val],
40584057
PermuteRowElements,
40594058
)
40604059

40614060
admat_val = random(3, 5)
40624061
self._compile_and_check(
40634062
[admat, aivec],
4064-
[PermuteRowElements()(admat, aivec, abool)],
4063+
[PermuteRowElements(inverse=False)(admat, aivec)],
40654064
[admat_val, aivec_val],
40664065
PermuteRowElements,
40674066
)
@@ -4070,7 +4069,7 @@ def test_PermuteRowElements(self):
40704069
adtens3_val = random(3, 2, 5)
40714070
self._compile_and_check(
40724071
[adtens3, aivec],
4073-
[PermuteRowElements()(adtens3, aivec, abool)],
4072+
[PermuteRowElements(inverse=True)(adtens3, aivec)],
40744073
[adtens3_val, aivec_val],
40754074
PermuteRowElements,
40764075
)
@@ -4083,7 +4082,7 @@ def test_PermuteRowElements(self):
40834082
admat_val = random(3, 5)
40844083
self._compile_and_check(
40854084
[admat, aimat],
4086-
[PermuteRowElements()(admat, aimat, abool)],
4085+
[PermuteRowElements(inverse=False)(admat, aimat)],
40874086
[admat_val, aimat_val],
40884087
PermuteRowElements,
40894088
)
@@ -4098,7 +4097,7 @@ def test_PermuteRowElements(self):
40984097
aitens3_val[1, ::, ::] = bimat_val
40994098
self._compile_and_check(
41004099
[admat, aitens3],
4101-
[PermuteRowElements()(admat, aitens3, abool)],
4100+
[PermuteRowElements(inverse=True)(admat, aitens3)],
41024101
[admat_val, aitens3_val],
41034102
PermuteRowElements,
41044103
)

0 commit comments

Comments
 (0)