@@ -646,12 +646,17 @@ class Repeat(Op):
 
     __props__ = ("axis",)
 
-    def __init__(self, axis=None):
+    def __init__(self, axis: int | None = None):
+        if axis is not None:
+            if not isinstance(axis, int) or axis < 0:
+                raise ValueError(
+                    f"Repeat only accepts positive integer axis or None, got {axis}"
+                )
         self.axis = axis
 
     def make_node(self, x, repeats):
         x = ptb.as_tensor_variable(x)
-        repeats = ptb.as_tensor_variable(repeats)
+        repeats = ptb.as_tensor_variable(repeats, dtype="int64")
 
         if repeats.dtype not in integer_dtypes:
             raise TypeError("repeats.dtype must be an integer.")
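A quick illustration of the stricter constructor contract (a hypothetical snippet, not part of the diff; it assumes `Repeat` is importable from `pytensor.tensor.extra_ops`, where this code lives): negative axes are no longer accepted by the Op itself and must be normalized by the `repeat` helper first.

```python
# Hypothetical check of the new constructor behavior.
from pytensor.tensor.extra_ops import Repeat

Repeat(axis=1)       # non-negative int: accepted
Repeat(axis=None)    # None: accepted
try:
    Repeat(axis=-1)  # negative axis: now rejected at construction time
except ValueError as err:
    print(err)  # Repeat only accepts positive integer axis or None, got -1
```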
@@ -687,58 +692,64 @@ def make_node(self, x, repeats):
         out_shape = list(x.type.shape)
         out_shape[self.axis] = None
 
-        out_type = TensorType(
-            x.dtype, shape=tuple(1 if s == 1 else None for s in out_shape)
-        )
-
+        out_type = TensorType(x.dtype, shape=out_shape)
         return Apply(self, [x, repeats], [out_type()])
 
     def perform(self, node, inputs, output_storage):
-        x = inputs[0]
-        repeats = inputs[1]
-        z = output_storage[0]
-        z[0] = np.repeat(x, repeats=repeats, axis=self.axis)
+        [x, repeats] = inputs
+        output_storage[0][0] = np.repeat(x, repeats=repeats, axis=self.axis)
 
     def connection_pattern(self, node):
         return [[True], [False]]
 
     def grad(self, inputs, gout):
         (x, repeats) = inputs
         (gz,) = gout
+        axis = self.axis
         if repeats.ndim == 0:
-            if self.axis is None:
-                axis = x.ndim
-            else:
-                if self.axis >= 0:
-                    axis = self.axis + 1
-                else:
-                    axis = self.axis + x.ndim + 1
-
-            shape = [x.shape[k] for k in range(x.ndim)]
-            shape.insert(axis, repeats)
+            # When repeats is a scalar (same number of reps for all elements),
+            # we can split the repetitions into their own axis with reshape
+            # and sum them back to the original element location
+            sum_axis = x.ndim if axis is None else axis + 1
+            shape = list(x.shape)
+            shape.insert(sum_axis, repeats)
+            gx = gz.reshape(shape).sum(axis=sum_axis)
 
-            return [
-                gz.reshape(shape, ndim=x.ndim + 1).sum(axis=axis),
-                DisconnectedType()(),
-            ]
         elif repeats.ndim == 1:
-            # For this implementation, we would need to specify the length
-            # of repeats in order to split gz in the right way to sum
-            # the good part.
-            raise NotImplementedError()
+            # To sum the gradients that belong to the same repeated x,
+            # we create a repeated eye and dot product it with the gradient.
+            axis_size = x.size if axis is None else x.shape[axis]
+            repeated_eye = repeat(
+                ptb.eye(axis_size), repeats, axis=0
+            )  # A sparse repeat would be neat
+
+            if axis is None:
+                gx = gz @ repeated_eye
+                # Undo the ravelling when axis=None
+                gx = gx.reshape(x.shape)
+            else:
+                # Place gradient axis at end for dot product
+                gx = ptb.moveaxis(gz, axis, -1)
+                gx = gx @ repeated_eye
+                # Place gradient back into the correct axis
+                gx = ptb.moveaxis(gx, -1, axis)
+
         else:
             raise ValueError()
 
+        return [gx, DisconnectedType()()]
+
     def infer_shape(self, fgraph, node, ins_shapes):
         i0_shapes = ins_shapes[0]
         repeats = node.inputs[1]
         out_shape = list(i0_shapes)
+        axis = self.axis
 
         # uint64 shape are not supported.
         dtype = None
         if repeats.dtype in ("uint8", "uint16", "uint32"):
             dtype = "int64"
-        if self.axis is None:
+        if axis is None:
             if repeats.ndim == 0:
                 if len(i0_shapes) == 0:
                     out_shape = [repeats]
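The two rewritten gradient branches rest on standard identities; here is a minimal NumPy sketch (not part of the diff) checking both tricks: reshape-and-sum for scalar `repeats`, and the repeated identity matrix for vector `repeats`.

```python
import numpy as np

x = np.arange(6.0).reshape(2, 3)

# Scalar repeats: the forward op is np.repeat(x, 2, axis=1), so gz has shape
# (2, 6). Splitting the repetitions into their own axis (sum_axis = axis + 1)
# and summing collapses the gradient back onto x's shape.
gz = np.ones((2, 6))
gx = gz.reshape(2, 3, 2).sum(axis=2)
assert gx.shape == x.shape

# Vector repeats along axis 0: a row-repeated identity matrix maps each output
# row back to the input row it came from, so one matmul sums the gradients
# that belong to the same repeated element.
repeats = np.array([1, 2])
gz = np.ones((3, 3))  # gradient of np.repeat(x, repeats, axis=0)
repeated_eye = np.repeat(np.eye(2), repeats, axis=0)  # shape (3, 2)
gx = np.moveaxis(np.moveaxis(gz, 0, -1) @ repeated_eye, -1, 0)
assert gx.shape == x.shape
```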
@@ -751,82 +762,115 @@ def infer_shape(self, fgraph, node, ins_shapes):
                 out_shape = [pt_sum(repeats, dtype=dtype)]
         else:
             if repeats.ndim == 0:
-                out_shape[self.axis] = out_shape[self.axis] * repeats
+                out_shape[axis] = out_shape[axis] * repeats
             else:
-                out_shape[self.axis] = pt_sum(repeats, dtype=dtype)
+                out_shape[axis] = pt_sum(repeats, dtype=dtype)
         return [out_shape]
 
 
-def repeat(x, repeats, axis=None):
-    """Repeat elements of an array.
+def repeat(
+    a: TensorLike, repeats: TensorLike, axis: int | None = None
+) -> TensorVariable:
+    """Repeat elements of a tensor.
 
-    It returns an array which has the same shape as `x`, except along the given
-    `axis`. The `axis` parameter is used to specify the axis along which values
-    are repeated. By default, a flattened version of `x` is used.
+    See :func:`numpy.repeat` for more information.
 
-    The number of repetitions for each element is `repeats`. `repeats` is
-    broadcasted to fit the length of the given `axis`.
 
     Parameters
     ----------
-    x
-        Input data, tensor variable.
-    repeats
-        int, scalar or tensor variable
+    a: tensor_like
+        Input tensor
+    repeats: tensor_like
+        The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
     axis : int, optional
+        The axis along which to repeat values. By default, use the flattened input array, and return a flat output array.
 
-    See Also
+    Returns
+    -------
+    repeated_tensor: TensorVariable
+        Output tensor which has the same shape as a, except along the given axis
+
+    Examples
     --------
-        tensor.tile
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        a = pt.arange(4).reshape((2, 2))
+        out = pt.repeat(a, repeats=[2, 3], axis=0)
+        print(out.eval())
+
+    .. testoutput::
+
+        [[0 1]
+         [0 1]
+         [2 3]
+         [2 3]
+         [2 3]]
+
+    When axis is None, the array is first flattened and then repeated
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        a = pt.arange(4).reshape((2, 2))
+        out = pt.repeat(a, repeats=[2, 3, 0, 1], axis=None)
+        print(out.eval())
+
+    .. testoutput::
+
+        [0 0 1 1 1 3]
+
 
     .. versionadded:: 0.6
 
     """
+    a = ptb.as_tensor_variable(a)
+
+    if axis is not None:
+        axis = normalize_axis_index(axis, a.ndim)
+
     repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)
 
     if repeats.ndim > 1:
         raise ValueError("The dimension of repeats should not exceed 1.")
 
     if repeats.ndim == 1 and not repeats.broadcastable[0]:
-        return Repeat(axis=axis)(x, repeats)
+        # We only use the Repeat Op for vector repeats
+        return Repeat(axis=axis)(a, repeats)
     else:
         if repeats.ndim == 1:
             repeats = repeats[0]
 
-        if x.dtype == "uint64":
+        if a.dtype == "uint64":
+            # Multiplying int64 (shape) by uint64 (repeats) yields a float64,
+            # which is not valid for the `reshape` operation at the end
             raise TypeError("repeat doesn't support dtype uint64")
 
         if axis is None:
             axis = 0
-            x = x.flatten()
-        else:
-            if axis >= x.ndim:
-                raise ValueError("Axis should not exceed x.ndim-1.")
-            if axis < 0:
-                axis = x.ndim + axis
+            a = a.flatten()
 
-        shape = [x.shape[i] for i in range(x.ndim)]
+        repeat_shape = list(a.shape)
 
-        # shape_ is the shape of the intermediate tensor which has
+        # alloc_shape is the shape of the intermediate tensor which has
         # an additional dimension comparing to x. We use alloc to
         # allocate space for this intermediate tensor to replicate x
         # along that additional dimension.
-        shape_ = shape[:]
-        shape_.insert(axis + 1, repeats)
+        alloc_shape = repeat_shape[:]
+        alloc_shape.insert(axis + 1, repeats)
 
-        # shape is now the shape of output, where shape[axis] becomes
+        # repeat_shape is now the shape of output, where shape[axis] becomes
         # shape[axis]*repeats.
-        shape[axis] = shape[axis] * repeats
-
-        # dims_ is the dimension of that intermediate tensor.
-        dims_ = list(np.arange(x.ndim))
-        dims_.insert(axis + 1, "x")
+        repeat_shape[axis] = repeat_shape[axis] * repeats
 
         # After the original tensor is duplicated along the additional
-        # dimension, we reshape it to the expected output shape, and
-        # return the output z.
-        z = ptb.alloc(x.dimshuffle(*dims_), *shape_).reshape(shape)
-        return z
+        # dimension, we reshape it to the expected output shape
+        return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
+            repeat_shape
+        )
 
 
 class Bartlett(Op):
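The scalar-repeats fast path above (`expand_dims`, `alloc`, `reshape`) is the usual broadcasting trick for consecutive repetition; a NumPy analogue (not part of the diff, with `np.broadcast_to` standing in for `ptb.alloc`):

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
repeats, axis = 4, 1

expanded = np.expand_dims(a, axis + 1)          # shape (2, 3, 1)
alloc_shape = [2, 3, repeats]                   # repeats inserted at axis + 1
tiled = np.broadcast_to(expanded, alloc_shape)  # stands in for ptb.alloc
out = tiled.reshape(2, 3 * repeats)             # repeat_shape[axis] * repeats

assert np.array_equal(out, np.repeat(a, repeats, axis=axis))
```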