
Commit ffdde1c

Implement gradient for vector repetitions
Also cleans up implementation and documentation
1 parent da4960b commit ffdde1c
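For orientation, here is a minimal sketch (not part of the commit) of the behaviour this change enables, assuming a PyTensor build that includes it: gradients can now flow through pt.repeat when repeats is a vector.

import numpy as np
import pytensor
import pytensor.tensor as pt

# Hypothetical example: each row of x is repeated a different number of times.
x = pt.dmatrix("x")
y = pt.repeat(x, repeats=[2, 3], axis=0)

# Before this commit, Repeat.grad raised NotImplementedError for vector repeats;
# now every copy's gradient is summed back into its source row.
g = pytensor.grad(y.sum(), x)
f = pytensor.function([x], g)
print(f(np.zeros((2, 2))))  # [[2. 2.] [3. 3.]]: row i receives repeats[i]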


2 files changed: +131 -73 lines


pytensor/tensor/extra_ops.py

+110-66
@@ -646,12 +646,17 @@ class Repeat(Op):
 
     __props__ = ("axis",)
 
-    def __init__(self, axis=None):
+    def __init__(self, axis: int | None = None):
+        if axis is not None:
+            if not isinstance(axis, int) or axis < 0:
+                raise ValueError(
+                    f"Repeat only accepts positive integer axis or None, got {axis}"
+                )
         self.axis = axis
 
     def make_node(self, x, repeats):
         x = ptb.as_tensor_variable(x)
-        repeats = ptb.as_tensor_variable(repeats)
+        repeats = ptb.as_tensor_variable(repeats, dtype="int64")
 
         if repeats.dtype not in integer_dtypes:
             raise TypeError("repeats.dtype must be an integer.")
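A side note on the hunk above (illustrative, not from the diff, and assuming this commit is applied): the Op itself now rejects anything but a non-negative integer axis or None, while the user-facing pt.repeat helper shown further down normalizes negative axes with normalize_axis_index before constructing the Op.

import pytensor.tensor as pt
from pytensor.tensor.extra_ops import Repeat

try:
    Repeat(axis=-1)  # the new __init__ check rejects negative axes at the Op level
except ValueError as err:
    print(err)  # Repeat only accepts positive integer axis or None, got -1

x = pt.matrix("x")
out = pt.repeat(x, repeats=[1, 2, 3], axis=-1)  # the helper maps axis=-1 to axis=1 for 2d input
print(out.owner.op.axis)  # 1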
@@ -687,58 +692,64 @@ def make_node(self, x, repeats):
                 out_shape = list(x.type.shape)
                 out_shape[self.axis] = None
 
-        out_type = TensorType(
-            x.dtype, shape=tuple(1 if s == 1 else None for s in out_shape)
-        )
-
+        out_type = TensorType(x.dtype, shape=out_shape)
         return Apply(self, [x, repeats], [out_type()])
 
     def perform(self, node, inputs, output_storage):
-        x = inputs[0]
-        repeats = inputs[1]
-        z = output_storage[0]
-        z[0] = np.repeat(x, repeats=repeats, axis=self.axis)
+        [x, repeats] = inputs
+        output_storage[0][0] = np.repeat(x, repeats=repeats, axis=self.axis)
 
     def connection_pattern(self, node):
         return [[True], [False]]
 
     def grad(self, inputs, gout):
         (x, repeats) = inputs
         (gz,) = gout
+        axis = self.axis
         if repeats.ndim == 0:
-            if self.axis is None:
-                axis = x.ndim
-            else:
-                if self.axis >= 0:
-                    axis = self.axis + 1
-                else:
-                    axis = self.axis + x.ndim + 1
-
-            shape = [x.shape[k] for k in range(x.ndim)]
-            shape.insert(axis, repeats)
+            # When repeats is a scalar (same number of reps for all elements),
+            # we can split the repetitions into their own axis with reshape and sum them back
+            # to the original element location
+            sum_axis = x.ndim if axis is None else axis + 1
+            shape = list(x.shape)
+            shape.insert(sum_axis, repeats)
+            gx = gz.reshape(shape).sum(axis=sum_axis)
 
-            return [
-                gz.reshape(shape, ndim=x.ndim + 1).sum(axis=axis),
-                DisconnectedType()(),
-            ]
         elif repeats.ndim == 1:
-            # For this implementation, we would need to specify the length
-            # of repeats in order to split gz in the right way to sum
-            # the good part.
-            raise NotImplementedError()
+            # To sum the gradients that belong to the same repeated x,
+            # we create a repeated eye and dot product it with the gradient.
+            axis_size = x.size if axis is None else x.shape[axis]
+            repeated_eye = repeat(
+                ptb.eye(axis_size), repeats, axis=0
+            )  # A sparse repeat would be neat
+
+            if axis is None:
+                gx = gz @ repeated_eye
+                # Undo the ravelling when axis=None
+                gx = gx.reshape(x.shape)
+            else:
+                # Place gradient axis at end for dot product
+                gx = ptb.moveaxis(gz, axis, -1)
+                gx = gx @ repeated_eye
+                # Place gradient back into the correct axis
+                gx = ptb.moveaxis(gx, -1, axis)
+
         else:
             raise ValueError()
 
+        return [gx, DisconnectedType()()]
+
     def infer_shape(self, fgraph, node, ins_shapes):
         i0_shapes = ins_shapes[0]
         repeats = node.inputs[1]
         out_shape = list(i0_shapes)
+        axis = self.axis
 
         # uint64 shape are not supported.
         dtype = None
         if repeats.dtype in ("uint8", "uint16", "uint32"):
             dtype = "int64"
-        if self.axis is None:
+        if axis is None:
             if repeats.ndim == 0:
                 if len(i0_shapes) == 0:
                     out_shape = [repeats]
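The gradient logic in the hunk above can be checked with plain NumPy; the following sketch (illustrative, not part of the diff) mirrors the two branches of Repeat.grad for a (3, 2) input repeated along axis 0.

import numpy as np

# Scalar repeats: every row is repeated the same number of times, so the copies
# can be split into their own axis with a reshape and summed back.
reps = 2
gz_scalar = np.ones((3 * reps, 2))          # upstream gradient of repeat(x, 2, axis=0)
gx_scalar = gz_scalar.reshape(3, reps, 2).sum(axis=1)
print(gx_scalar)                            # every entry is 2.0 (number of copies)

# Vector repeats: a repeated identity matrix maps each output row back to its
# source row, so a matmul sums the contributions of all copies.
repeats = np.array([2, 3, 2])
gz_vector = np.ones((repeats.sum(), 2))     # upstream gradient of repeat(x, repeats, axis=0)
repeated_eye = np.repeat(np.eye(3), repeats, axis=0)   # shape (7, 3)
gx_vector = np.moveaxis(np.moveaxis(gz_vector, 0, -1) @ repeated_eye, -1, 0)
print(gx_vector)                            # rows are [2. 2.], [3. 3.], [2. 2.]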
@@ -751,82 +762,115 @@ def infer_shape(self, fgraph, node, ins_shapes):
                 out_shape = [pt_sum(repeats, dtype=dtype)]
         else:
             if repeats.ndim == 0:
-                out_shape[self.axis] = out_shape[self.axis] * repeats
+                out_shape[axis] = out_shape[axis] * repeats
             else:
-                out_shape[self.axis] = pt_sum(repeats, dtype=dtype)
+                out_shape[axis] = pt_sum(repeats, dtype=dtype)
         return [out_shape]
 
 
-def repeat(x, repeats, axis=None):
-    """Repeat elements of an array.
+def repeat(
+    a: TensorLike, repeats: TensorLike, axis: int | None = None
+) -> TensorVariable:
+    """Repeat elements of a tensor.
 
-    It returns an array which has the same shape as `x`, except along the given
-    `axis`. The `axis` parameter is used to specify the axis along which values
-    are repeated. By default, a flattened version of `x` is used.
+    See :func:`numpy.repeat` for more information.
 
-    The number of repetitions for each element is `repeats`. `repeats` is
-    broadcasted to fit the length of the given `axis`.
 
     Parameters
     ----------
-    x
-        Input data, tensor variable.
-    repeats
-        int, scalar or tensor variable
+    a: tensor_like
+        Input tensor
+    repeats: tensor_like
+        The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
     axis : int, optional
+        The axis along which to repeat values. By default, use the flattened input array, and return a flat output array.
 
-    See Also
+    Returns
+    -------
+    repeated_tensor: TensorVariable
+        Output tensor which has the same shape as a, except along the given axis
+
+    Examples
     --------
-    tensor.tile
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        a = pt.arange(4).reshape((2, 2))
+        out = pt.repeat(a, repeats=[2, 3], axis=0)
+        print(out.eval())
+
+    .. testoutput::
+
+        [[0 1]
+         [0 1]
+         [2 3]
+         [2 3]
+         [2 3]]
+
+    When axis is None, the array is first flattened and then repeated
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        a = pt.arange(4).reshape((2, 2))
+        out = pt.repeat(a, repeats=[2, 3, 0, 1], axis=None)
+        print(out.eval())
+
+    .. testoutput::
+
+        [0 0 1 1 1 3]
+
 
     .. versionadded:: 0.6
 
     """
+    a = ptb.as_tensor_variable(a)
+
+    if axis is not None:
+        axis = normalize_axis_index(axis, a.ndim)
+
     repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)
 
     if repeats.ndim > 1:
         raise ValueError("The dimension of repeats should not exceed 1.")
 
     if repeats.ndim == 1 and not repeats.broadcastable[0]:
-        return Repeat(axis=axis)(x, repeats)
+        # We only use the Repeat Op for vector repeats
+        return Repeat(axis=axis)(a, repeats)
     else:
         if repeats.ndim == 1:
             repeats = repeats[0]
 
-        if x.dtype == "uint64":
+        if a.dtype == "uint64":
+            # Multiplying int64 (shape) by uint64 (repeats) yields a float64
+            # Which is not valid for the `reshape` operation at the end
             raise TypeError("repeat doesn't support dtype uint64")
 
         if axis is None:
             axis = 0
-            x = x.flatten()
-        else:
-            if axis >= x.ndim:
-                raise ValueError("Axis should not exceed x.ndim-1.")
-            if axis < 0:
-                axis = x.ndim + axis
+            a = a.flatten()
 
-        shape = [x.shape[i] for i in range(x.ndim)]
+        repeat_shape = list(a.shape)
 
-        # shape_ is the shape of the intermediate tensor which has
+        # alloc_shape is the shape of the intermediate tensor which has
         # an additional dimension comparing to x. We use alloc to
         # allocate space for this intermediate tensor to replicate x
         # along that additional dimension.
-        shape_ = shape[:]
-        shape_.insert(axis + 1, repeats)
+        alloc_shape = repeat_shape[:]
+        alloc_shape.insert(axis + 1, repeats)
 
-        # shape is now the shape of output, where shape[axis] becomes
+        # repeat_shape is now the shape of output, where shape[axis] becomes
         # shape[axis]*repeats.
-        shape[axis] = shape[axis] * repeats
-
-        # dims_ is the dimension of that intermediate tensor.
-        dims_ = list(np.arange(x.ndim))
-        dims_.insert(axis + 1, "x")
+        repeat_shape[axis] = repeat_shape[axis] * repeats
 
         # After the original tensor is duplicated along the additional
-        # dimension, we reshape it to the expected output shape, and
-        # return the output z.
-        z = ptb.alloc(x.dimshuffle(*dims_), *shape_).reshape(shape)
-        return z
+        # dimension, we reshape it to the expected output shape
+        return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
+            repeat_shape
+        )
 
 
 class Bartlett(Op):
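For the scalar-repeats fallback at the end of the file above, the alloc/expand_dims/reshape construction is the same trick one would write by hand in NumPy; a small sketch (illustrative, not from the diff):

import numpy as np

a = np.arange(6).reshape(2, 3)
repeats, axis = 4, 1

# Insert a new length-`repeats` axis right after `axis`, broadcast the data into
# it, then merge it with `axis` so every element appears `repeats` times in a row.
expanded = np.expand_dims(a, axis + 1)                    # shape (2, 3, 1)
allocated = np.broadcast_to(expanded, (2, 3, repeats))    # what ptb.alloc produces
out = allocated.reshape(2, 3 * repeats)                   # shape (2, 12)

assert np.array_equal(out, np.repeat(a, repeats, axis=axis))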

tests/tensor/test_extra_ops.py

+21-7
@@ -595,7 +595,6 @@ def test_basic(self, ndim, dtype):
             isinstance(n.op, Repeat) for n in f.maker.fgraph.toposort()
         )
 
-    @pytest.mark.slow
     @pytest.mark.parametrize("ndim", [1, 3])
     @pytest.mark.parametrize("dtype", ["int8", "uint8", "uint64"])
     def test_infer_shape(self, ndim, dtype):
@@ -606,6 +605,10 @@ def test_infer_shape(self, ndim, dtype):
         a = rng.random(shp).astype(config.floatX)
 
         for axis in self._possible_axis(ndim):
+            if axis is not None and axis < 0:
+                # Operator does not support negative axis
+                continue
+
             r_var = scalar(dtype=dtype)
             r = np.asarray(3, dtype=dtype)
             if dtype in self.numpy_unsupported_dtypes:
@@ -635,12 +638,23 @@ def test_infer_shape(self, ndim, dtype):
                 self.op_class,
             )
 
-    @pytest.mark.parametrize("ndim", range(3))
-    def test_grad(self, ndim):
-        a = np.random.random((10,) * ndim).astype(config.floatX)
-
-        for axis in self._possible_axis(ndim):
-            utt.verify_grad(lambda x: Repeat(axis=axis)(x, 3), [a])
+    @pytest.mark.parametrize("x_ndim", [2, 3], ids=lambda x: f"x_ndim={x}")
+    @pytest.mark.parametrize("repeats_ndim", [0, 1], ids=lambda r: f"repeats_ndim={r}")
+    @pytest.mark.parametrize("axis", [None, 0, 1], ids=lambda a: f"axis={a}")
+    def test_grad(self, x_ndim, repeats_ndim, axis):
+        rng = np.random.default_rng(
+            [653, x_ndim, 2 if axis is None else axis, repeats_ndim]
+        )
+        x_test = rng.normal(size=np.arange(3, 3 + x_ndim))
+        if repeats_ndim == 0:
+            repeats_size = ()
+        else:
+            repeats_size = (x_test.shape[axis] if axis is not None else x_test.size,)
+        repeats = rng.integers(1, 6, size=repeats_size)
+        utt.verify_grad(
+            lambda x: Repeat(axis=axis)(x, repeats),
+            [x_test],
+        )
 
     def test_broadcastable(self):
         x = TensorType(config.floatX, shape=(None, 1, None))()
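The new test relies on utt.verify_grad, which compares the symbolic gradient against finite differences; the sketch below (illustrative, not part of the test suite) spells that check out by hand for one vector-repeats case, assuming this commit is applied.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dmatrix("x")
repeats = np.array([2, 3, 1])
cost = pt.repeat(x, repeats, axis=0).sum() ** 2
cost_fn = pytensor.function([x], cost)
grad_fn = pytensor.function([x], pytensor.grad(cost, x))

rng = np.random.default_rng(653)
x0 = rng.normal(size=(3, 4))

# Central finite differences, one coordinate at a time.
eps = 1e-6
numeric = np.zeros_like(x0)
for idx in np.ndindex(*x0.shape):
    dx = np.zeros_like(x0)
    dx[idx] = eps
    numeric[idx] = (cost_fn(x0 + dx) - cost_fn(x0 - dx)) / (2 * eps)

assert np.allclose(grad_fn(x0), numeric, atol=1e-4)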

0 commit comments
