21
21
import pytensor .scalar .sharedvar
22
22
from pytensor import compile , config , printing
23
23
from pytensor import scalar as ps
24
+ from pytensor .compile .builders import OpFromGraph
24
25
from pytensor .gradient import DisconnectedType , grad_undefined
25
26
from pytensor .graph import RewriteDatabaseQuery
26
27
from pytensor .graph .basic import Apply , Constant , Variable , equal_computations
@@ -1334,6 +1335,25 @@ def infer_shape(self, fgraph, node, in_shapes):
1334
1335
def grad (self , inp , grads ):
1335
1336
return [grad_undefined (self , i , inp [i ]) for i in range (3 )]
1336
1337
1338
+ @staticmethod
1339
+ def is_offset_zero (node ) -> bool :
1340
+ """
1341
+ Test if an Eye Op has a diagonal offset of zero
1342
+
1343
+ Parameters
1344
+ ----------
1345
+ node
1346
+ Eye node to test
1347
+
1348
+ Returns
1349
+ -------
1350
+ is_offset_zero: bool
1351
+ True if the offset is zero (``k = 0``).
1352
+ """
1353
+
1354
+ offset = node .inputs [- 1 ]
1355
+ return isinstance (offset , Constant ) and offset .data .item () == 0
1356
+
1337
1357
1338
1358
def eye (n , m = None , k = 0 , dtype = None ):
1339
1359
"""Return a 2-D array with ones on the diagonal and zeros elsewhere.
@@ -3749,109 +3769,37 @@ def trace(a, offset=0, axis1=0, axis2=1):
3749
3769
return diagonal (a , offset = offset , axis1 = axis1 , axis2 = axis2 ).sum (- 1 )
3750
3770
3751
3771
3752
- class AllocDiag (Op ):
3753
- """An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
3772
+ class AllocDiag (OpFromGraph ):
3773
+ """
3774
+ Wrapper Op for alloc_diag graphs
3775
+ """
3754
3776
3755
- __props__ = ("offset" , " axis1" , "axis2" )
3777
+ __props__ = ("axis1" , "axis2" )
3756
3778
3757
- def __init__ (self , offset = 0 , axis1 = 0 , axis2 = 1 ):
3758
- """
3759
- Parameters
3760
- ----------
3761
- offset: int
3762
- Offset of the diagonal from the main diagonal defined by `axis1`
3763
- and `axis2`. Can be positive or negative. Defaults to main
3764
- diagonal (i.e. 0).
3765
- axis1: int
3766
- Axis to be used as the first axis of the 2-D sub-arrays to which
3767
- the diagonals will be allocated. Defaults to first axis (i.e. 0).
3768
- axis2: int
3769
- Axis to be used as the second axis of the 2-D sub-arrays to which
3770
- the diagonals will be allocated. Defaults to second axis (i.e. 1).
3771
- """
3772
- warnings .warn (
3773
- "AllocDiag is deprecated. Use `alloc_diag` instead" ,
3774
- FutureWarning ,
3775
- )
3776
- self .offset = offset
3777
- if axis1 < 0 or axis2 < 0 :
3778
- raise NotImplementedError ("AllocDiag does not support negative axis" )
3779
- if axis1 == axis2 :
3780
- raise ValueError ("axis1 and axis2 cannot be the same" )
3779
+ def __init__ (self , * args , axis1 , axis2 , offset , ** kwargs ):
3781
3780
self .axis1 = axis1
3782
3781
self .axis2 = axis2
3782
+ self .offset = offset
3783
3783
3784
- def make_node (self , diag ):
3785
- diag = as_tensor_variable (diag )
3786
- if diag .type .ndim < 1 :
3787
- raise ValueError (
3788
- "AllocDiag needs an input with 1 or more dimensions" , diag .type
3789
- )
3790
- return Apply (
3791
- self ,
3792
- [diag ],
3793
- [diag .type .clone (shape = (None ,) * (diag .ndim + 1 ))()],
3794
- )
3795
-
3796
- def perform (self , node , inputs , outputs ):
3797
- (x ,) = inputs
3798
- (z ,) = outputs
3799
-
3800
- axis1 = np .minimum (self .axis1 , self .axis2 )
3801
- axis2 = np .maximum (self .axis1 , self .axis2 )
3802
- offset = self .offset
3803
-
3804
- # Create array with one extra dimension for resulting matrix
3805
- result_shape = x .shape [:- 1 ] + (x .shape [- 1 ] + abs (offset ),) * 2
3806
- result = np .zeros (result_shape , dtype = x .dtype )
3807
-
3808
- # Create slice for diagonal in final 2 axes
3809
- idxs = np .arange (x .shape [- 1 ])
3810
- diagonal_slice = (len (result_shape ) - 2 ) * [slice (None )] + [
3811
- idxs + np .maximum (0 , - offset ),
3812
- idxs + np .maximum (0 , offset ),
3813
- ]
3814
-
3815
- # Fill in final 2 axes with x
3816
- result [tuple (diagonal_slice )] = x
3817
-
3818
- if len (x .shape ) > 1 :
3819
- # Re-order axes so they correspond to diagonals at axis1, axis2
3820
- axes = list (range (len (x .shape [:- 1 ])))
3821
- last_idx = axes [- 1 ]
3822
- axes = axes [:axis1 ] + [last_idx + 1 ] + axes [axis1 :]
3823
- axes = axes [:axis2 ] + [last_idx + 2 ] + axes [axis2 :]
3824
- result = result .transpose (axes )
3825
-
3826
- z [0 ] = result
3827
-
3828
- def grad (self , inputs , gout ):
3829
- (gz ,) = gout
3830
- return [diagonal (gz , offset = self .offset , axis1 = self .axis1 , axis2 = self .axis2 )]
3831
-
3832
- def infer_shape (self , fgraph , nodes , shapes ):
3833
- (x_shape ,) = shapes
3834
- axis1 = np .minimum (self .axis1 , self .axis2 )
3835
- axis2 = np .maximum (self .axis1 , self .axis2 )
3784
+ super ().__init__ (* args , ** kwargs , strict = True )
3836
3785
3837
- result_shape = list (x_shape [:- 1 ])
3838
- diag_shape = x_shape [- 1 ] + abs (self .offset )
3839
- result_shape = result_shape [:axis1 ] + [diag_shape ] + result_shape [axis1 :]
3840
- result_shape = result_shape [:axis2 ] + [diag_shape ] + result_shape [axis2 :]
3841
- return [tuple (result_shape )]
3786
+ @staticmethod
3787
+ def is_offset_zero (node ) -> bool :
3788
+ """
3789
+ Test if an AllocDiag Op has a diagonal offset of zero
3842
3790
3843
- def __setstate__ (self , state ):
3844
- if "view_map" in state :
3845
- del state ["view_map" ]
3791
+ Parameters
3792
+ ----------
3793
+ node
3794
+ AllocDiag node to test
3846
3795
3847
- self .__dict__ .update (state )
3796
+ Returns
3797
+ -------
3798
+ is_offset_zero: bool
3799
+ True if the offset is zero (``k = 0``).
3800
+ """
3848
3801
3849
- if "offset" not in state :
3850
- self .offset = 0
3851
- if "axis1" not in state :
3852
- self .axis1 = 0
3853
- if "axis2" not in state :
3854
- self .axis2 = 1
3802
+ return node .op .offset == 0
3855
3803
3856
3804
3857
3805
def alloc_diag (diag , offset = 0 , axis1 = 0 , axis2 = 1 ):
@@ -3862,6 +3810,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
3862
3810
from pytensor .tensor import set_subtensor
3863
3811
3864
3812
diag = as_tensor_variable (diag )
3813
+
3865
3814
axis1 , axis2 = normalize_axis_tuple ((axis1 , axis2 ), ndim = diag .type .ndim + 1 )
3866
3815
if axis1 > axis2 :
3867
3816
axis1 , axis2 = axis2 , axis1
@@ -3888,7 +3837,9 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
3888
3837
axes = axes [:axis2 ] + [last_idx + 2 ] + axes [axis2 :]
3889
3838
result = result .transpose (axes )
3890
3839
3891
- return result
3840
+ return AllocDiag (
3841
+ inputs = [diag ], outputs = [result ], axis1 = axis1 , axis2 = axis2 , offset = offset
3842
+ )(diag )
3892
3843
3893
3844
3894
3845
def diag (v , k = 0 ):
0 commit comments