@@ -3727,112 +3727,7 @@ def trace(a, offset=0, axis1=0, axis2=1):
3727
3727
return diagonal (a , offset = offset , axis1 = axis1 , axis2 = axis2 ).sum (- 1 )
3728
3728
3729
3729
3730
- class AllocDiag (Op ):
3731
- """An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
3732
-
3733
- __props__ = ("offset" , "axis1" , "axis2" )
3734
-
3735
- def __init__ (self , offset = 0 , axis1 = 0 , axis2 = 1 ):
3736
- """
3737
- Parameters
3738
- ----------
3739
- offset: int
3740
- Offset of the diagonal from the main diagonal defined by `axis1`
3741
- and `axis2`. Can be positive or negative. Defaults to main
3742
- diagonal (i.e. 0).
3743
- axis1: int
3744
- Axis to be used as the first axis of the 2-D sub-arrays to which
3745
- the diagonals will be allocated. Defaults to first axis (i.e. 0).
3746
- axis2: int
3747
- Axis to be used as the second axis of the 2-D sub-arrays to which
3748
- the diagonals will be allocated. Defaults to second axis (i.e. 1).
3749
- """
3750
- warnings .warn (
3751
- "AllocDiag is deprecated. Use `alloc_diag` instead" ,
3752
- FutureWarning ,
3753
- )
3754
- self .offset = offset
3755
- if axis1 < 0 or axis2 < 0 :
3756
- raise NotImplementedError ("AllocDiag does not support negative axis" )
3757
- if axis1 == axis2 :
3758
- raise ValueError ("axis1 and axis2 cannot be the same" )
3759
- self .axis1 = axis1
3760
- self .axis2 = axis2
3761
-
3762
- def make_node (self , diag ):
3763
- diag = as_tensor_variable (diag )
3764
- if diag .type .ndim < 1 :
3765
- raise ValueError (
3766
- "AllocDiag needs an input with 1 or more dimensions" , diag .type
3767
- )
3768
- return Apply (
3769
- self ,
3770
- [diag ],
3771
- [diag .type .clone (shape = (None ,) * (diag .ndim + 1 ))()],
3772
- )
3773
-
3774
- def perform (self , node , inputs , outputs ):
3775
- (x ,) = inputs
3776
- (z ,) = outputs
3777
-
3778
- axis1 = np .minimum (self .axis1 , self .axis2 )
3779
- axis2 = np .maximum (self .axis1 , self .axis2 )
3780
- offset = self .offset
3781
-
3782
- # Create array with one extra dimension for resulting matrix
3783
- result_shape = x .shape [:- 1 ] + (x .shape [- 1 ] + abs (offset ),) * 2
3784
- result = np .zeros (result_shape , dtype = x .dtype )
3785
-
3786
- # Create slice for diagonal in final 2 axes
3787
- idxs = np .arange (x .shape [- 1 ])
3788
- diagonal_slice = (len (result_shape ) - 2 ) * [slice (None )] + [
3789
- idxs + np .maximum (0 , - offset ),
3790
- idxs + np .maximum (0 , offset ),
3791
- ]
3792
-
3793
- # Fill in final 2 axes with x
3794
- result [tuple (diagonal_slice )] = x
3795
-
3796
- if len (x .shape ) > 1 :
3797
- # Re-order axes so they correspond to diagonals at axis1, axis2
3798
- axes = list (range (len (x .shape [:- 1 ])))
3799
- last_idx = axes [- 1 ]
3800
- axes = axes [:axis1 ] + [last_idx + 1 ] + axes [axis1 :]
3801
- axes = axes [:axis2 ] + [last_idx + 2 ] + axes [axis2 :]
3802
- result = result .transpose (axes )
3803
-
3804
- z [0 ] = result
3805
-
3806
- def grad (self , inputs , gout ):
3807
- (gz ,) = gout
3808
- return [diagonal (gz , offset = self .offset , axis1 = self .axis1 , axis2 = self .axis2 )]
3809
-
3810
- def infer_shape (self , fgraph , nodes , shapes ):
3811
- (x_shape ,) = shapes
3812
- axis1 = np .minimum (self .axis1 , self .axis2 )
3813
- axis2 = np .maximum (self .axis1 , self .axis2 )
3814
-
3815
- result_shape = list (x_shape [:- 1 ])
3816
- diag_shape = x_shape [- 1 ] + abs (self .offset )
3817
- result_shape = result_shape [:axis1 ] + [diag_shape ] + result_shape [axis1 :]
3818
- result_shape = result_shape [:axis2 ] + [diag_shape ] + result_shape [axis2 :]
3819
- return [tuple (result_shape )]
3820
-
3821
- def __setstate__ (self , state ):
3822
- if "view_map" in state :
3823
- del state ["view_map" ]
3824
-
3825
- self .__dict__ .update (state )
3826
-
3827
- if "offset" not in state :
3828
- self .offset = 0
3829
- if "axis1" not in state :
3830
- self .axis1 = 0
3831
- if "axis2" not in state :
3832
- self .axis2 = 1
3833
-
3834
-
3835
- class AllocDiag2 (OpFromGraph ):
3730
+ class AllocDiag (OpFromGraph ):
3836
3731
"""
3837
3732
Wrapper Op for alloc_diag graphs
3838
3733
"""
@@ -3883,7 +3778,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
3883
3778
axes = axes [:axis2 ] + [last_idx + 2 ] + axes [axis2 :]
3884
3779
result = result .transpose (axes )
3885
3780
3886
- return AllocDiag2 (
3781
+ return AllocDiag (
3887
3782
inputs = [diag ], outputs = [result ], offset = offset , axis1 = axis1 , axis2 = axis2
3888
3783
)(diag )
3889
3784
0 commit comments